pytorch/torch/overrides.py
Edward Yang aacc722aec Dispatch to Python via __torch_dispatch__ (#59760)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760

See https://github.com/pytorch/pytorch/issues/59049

There are several moving parts to this PR; I'll structure this explanation so that the straightforward parts come first, followed by the less straightforward ones.

**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans all the arguments for Tensor arguments, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser`, which actually does all of the logic for calling out to torch dispatch (in particular, this function handles multiple-dispatch situations for you). Because we use a different magic-method name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept the name of the magic method to look for when testing whether Tensors have custom handling or not. Unlike `__torch_function__`, there is no `__torch_dispatch__` on Tensor classes by default.

**Maintaining the Python dispatch key.** In order to get to the dispatch-to-Python logic, we must tag Tensors that have the `__torch_dispatch__` magic method with the newly added Python dispatch key (kept separate from PythonFuncTorch to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging whether a Tensor is participating in Python dispatch or not. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (THPVariable_NewWithVar), testing whether `__torch_dispatch__` exists with the newly added `check_has_torch_dispatch`.

**Shallow copy and detach.** For the simple examples tested in this PR, most creations of Tensor route through the dispatcher. The exception to this is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backward. When a Tensor participates in Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead directly call into `__torch_dispatch__` to perform a `detach` operation (in the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from c10::TensorImpl, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamic dispatch to the Python interpreter associated with a TensorImpl.

**torchdeploy compatibility.** The dispatch-to-Python logic cannot be registered directly with the dispatcher, as it is compiled into the Python library, which gets loaded multiple times, once per torchdeploy interpreter. Thus, we must employ a two-phase process. First, we register a fallback inside a non-Python library (aten/src/ATen/core/PythonFallbackKernel.cpp). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first one that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via "dispatch", which goes to the correct torchdeploy interpreter to handle dispatching to actual Python.

**Testing.** We provide a simple example of a LoggingTensor for testing, which can be used to generate TorchScript-like traces that record which operations get called on a Tensor. Although a LoggingTensor would be better implemented via an is-a relationship rather than a has-a relationship (as is done in the test), we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly.
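
To give a flavor, a minimal LoggingTensor-style wrapper might look roughly like this (an illustrative sketch, not the actual test code; the class name and details here are made up):

```python
import torch

class LoggingTensor(torch.Tensor):
    # has-a composition: the real data lives in self.elem
    # __torch_function__ must be disabled, per the limitation noted below
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.elem if isinstance(t, LoggingTensor) else t
        print(func.__name__)  # one trace line per op that reaches dispatch
        # after unwrapping, the args carry no Python dispatch key, so this
        # call goes straight to the regular kernels without recursing
        out = func(*[unwrap(a) for a in args], **(kwargs or {}))
        return LoggingTensor(out) if isinstance(out, torch.Tensor) else out
```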

**Known limitations.**

* We haven't adjusted any operator code, so some patterns may not work (as they lose the Python subclass in an unrecoverable way)
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl`; otherwise things don't work quite correctly (in particular, what is being disabled is the default subclass-preservation behavior).
* We don't ever populate kwargs, even when an argument is kwarg-only

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D29017912

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Pulled By: ezyang

fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
2021-06-25 11:50:32 -07:00


"""
Python implementation of ``__torch_function__``
While most of the torch API and handling for ``__torch_function__`` happens
at the C++ level, some of the torch API is written in Python so we need
python-level handling for ``__torch_function__`` overrides as well. The main
developer-facing functionality in this file are handle_torch_function and
has_torch_function. See torch/functional.py and test/test_overrides.py
for usage examples.
Note
----
heavily inspired by NumPy's ``__array_function__`` (see:
https://github.com/pytorch/pytorch/issues/24015 and
https://www.numpy.org/neps/nep-0018-array-function-protocol.html
)
If changing this file in a way that can affect ``__torch_function__`` overhead,
please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
instructions in the ``README.md`` in that directory.
"""
import __future__
import collections
import functools
import types
from typing import Dict, Set, List, Any, Callable, Iterable, Type
import torch
from torch._C import (
_has_torch_function, _has_torch_function_unary,
_has_torch_function_variadic, _add_docstr)
__all__ = [
"get_ignored_functions",
"get_overridable_functions",
"get_testing_overrides",
"handle_torch_function",
"has_torch_function",
"is_tensor_like",
"is_tensor_method_or_property",
"wrap_torch_function",
]
@functools.lru_cache(None)
def get_ignored_functions() -> Set[Callable]:
"""
Return public functions that cannot be overridden by ``__torch_function__``.
Returns
-------
Set[Callable]
A tuple of functions that are publicly available in the torch API but cannot
be overridden with ``__torch_function__``. Mostly this is because none of the
arguments of these functions are tensors or tensor-likes.
Examples
--------
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
"""
Tensor = torch.Tensor
return {
torch.typename,
torch.is_tensor,
torch.is_storage,
torch.set_default_tensor_type,
torch.set_rng_state,
torch.get_rng_state,
torch.manual_seed,
torch.initial_seed,
torch.seed,
torch.save,
torch.load,
torch.set_printoptions,
torch.fork,
torch.get_default_dtype,
torch.get_num_interop_threads,
torch.get_num_threads,
torch.init_num_threads,
torch.import_ir_module,
torch.import_ir_module_from_buffer,
torch.is_anomaly_enabled,
torch.is_grad_enabled,
torch.merge_type_from_type_comment,
torch.parse_ir,
torch.parse_schema,
torch.parse_type_comment,
torch.set_anomaly_enabled,
torch.set_flush_denormal,
torch.set_num_interop_threads,
torch.set_num_threads,
torch.wait,
torch.as_tensor,
torch.from_numpy,
torch.get_device,
torch.tensor,
torch.default_generator,
torch.has_cuda,
torch.has_cudnn,
torch.has_lapack,
torch.device,
torch.dtype,
torch.finfo,
torch.has_mkl,
torch.has_mkldnn,
torch.has_openmp,
torch.iinfo,
torch.memory_format,
torch.qscheme,
torch.set_grad_enabled,
torch.no_grad,
torch.enable_grad,
torch.inference_mode,
torch.is_inference_mode_enabled,
torch.layout,
torch.align_tensors,
torch.arange,
torch.as_strided,
torch.bartlett_window,
torch.blackman_window,
torch.broadcast_shapes,
torch.can_cast,
torch.cudnn_affine_grid_generator,
torch.cudnn_batch_norm,
torch.cudnn_convolution,
torch.cudnn_convolution_transpose,
torch.cudnn_convolution_relu,
torch.cudnn_convolution_add_relu,
torch.cudnn_grid_sampler,
torch.cudnn_is_acceptable,
torch.empty,
torch.empty_strided,
torch.empty_quantized,
torch.eye,
torch.fft.fftfreq,
torch.fft.rfftfreq,
torch.from_file,
torch.full,
torch.hamming_window,
torch.hann_window,
torch.kaiser_window,
torch.linspace,
torch.logspace,
torch.mkldnn_adaptive_avg_pool2d,
torch.mkldnn_convolution,
torch.mkldnn_convolution_backward_weights,
torch.mkldnn_max_pool2d,
torch.mkldnn_max_pool3d,
torch.mkldnn_linear_backward_weights,
torch.normal,
torch.ones,
torch.promote_types,
torch.rand,
torch.randn,
torch.randint,
torch.randperm,
torch.range,
torch.result_type,
torch.scalar_tensor,
torch.sparse_coo_tensor,
torch.sparse_csr_tensor,
torch.tril_indices,
torch.triu_indices,
torch.vander,
torch.zeros,
torch._jit_internal.boolean_dispatch,
torch.nn.functional.assert_int_or_pair,
torch.nn.functional.upsample,
torch.nn.functional.upsample_bilinear,
torch.nn.functional.upsample_nearest,
torch.nn.functional.has_torch_function,
torch.nn.functional.has_torch_function_unary,
torch.nn.functional.has_torch_function_variadic,
torch.nn.functional.handle_torch_function,
torch.nn.functional.sigmoid,
torch.nn.functional.hardsigmoid,
torch.nn.functional.tanh,
has_torch_function,
handle_torch_function,
torch.set_autocast_enabled,
torch.is_autocast_enabled,
torch.clear_autocast_cache,
torch.set_autocast_cpu_enabled,
torch.is_autocast_cpu_enabled,
torch.set_autocast_cpu_dtype,
torch.get_autocast_cpu_dtype,
torch.autocast_increment_nesting,
torch.autocast_decrement_nesting,
torch.nn.functional.hardswish,
torch.is_vulkan_available,
torch.is_deterministic,
torch.are_deterministic_algorithms_enabled,
torch.use_deterministic_algorithms,
torch.set_deterministic,
torch.unify_type_list,
torch.is_warn_always_enabled,
torch.set_warn_always,
torch.vitals_enabled,
torch.set_vital,
Tensor.__delitem__,
Tensor.__dir__,
Tensor.__getattribute__,
Tensor.__init__,
Tensor.__iter__,
Tensor.__init_subclass__,
Tensor.__delattr__,
Tensor.__setattr__,
Tensor.__torch_function__,
Tensor.__new__,
Tensor.__class__,
Tensor.__subclasshook__,
Tensor.as_subclass,
Tensor.reinforce,
Tensor.new,
Tensor.new_tensor,
Tensor.new_empty,
Tensor.new_empty_strided,
Tensor.new_zeros,
Tensor.new_ones,
Tensor.new_full,
Tensor._make_subclass,
Tensor.stride,
Tensor.unflatten,
Tensor.to_sparse_csr,
Tensor._reduce_ex_internal,
Tensor._fix_weakref,
Tensor._python_dispatch.__get__,
Tensor._conj,
Tensor._conj_physical,
}
@functools.lru_cache(None)
def get_default_nowrap_functions() -> Set[Callable]:
"""
Return public functions that do not wrap in a subclass when invoked by
the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
these functions represent field accesses (i.e., retrieving a Tensor that
is stored somewhere on the Tensor) as opposed to computation. Users of
these functions expect object identity to be preserved over multiple accesses
(e.g., ``a.grad is a.grad``) which cannot be upheld if we're wrapping on
the fly every time (furthermore, the tensor stored here might already be
the subclass, in which case wrapping really ought not to happen).
Not ALL property accessors have this property; for example ``Tensor.T`` actually
just creates a new transposed tensor on the fly, and so we SHOULD interpose on
these calls (you need to check the implementation of the function to see if
this is the case or not). Additionally, if a property accessor doesn't return a Tensor,
it doesn't have to be on this list (though it is harmless if it is).
"""
Tensor = torch.Tensor
return {
Tensor._base.__get__,
Tensor.grad.__get__,
Tensor._grad.__get__,
}
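# Illustrative check of the identity contract described above (not part of
# this module's API):
#
#     x = torch.randn(2, requires_grad=True)
#     (x * 2).sum().backward()
#     assert x.grad is x.grad  # repeated accesses must return the same object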
@functools.lru_cache(None)
def get_testing_overrides() -> Dict[Callable, Callable]:
"""Return a dict containing dummy overrides for all overridable functions
Returns
-------
Dict[Callable, Callable]
A dictionary that maps overridable functions in the PyTorch API to
lambda functions that have the same signature as the real function
and unconditionally return -1. These lambda functions are useful
for testing API coverage for a type that defines ``__torch_function__``.
Examples
--------
>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
"""
# Every function in the PyTorch API that can be overridden needs an entry
# in this dict.
#
# Optimally we would use inspect to get the function signature and define
# the lambda function procedurally but that is blocked by generating
# function signatures for native kernels that can be consumed by inspect.
# See Issue #28233.
Tensor = torch.Tensor
ret: Dict[Callable, Callable] = {
torch.abs: lambda input, out=None: -1,
torch.absolute: lambda input, out=None: -1,
torch.adaptive_avg_pool1d: lambda input, output_size: -1,
torch.adaptive_max_pool1d: lambda input, output_size: -1,
torch.acos: lambda input, out=None: -1,
torch.arccos: lambda input, out=None: -1,
torch.acosh: lambda input, out=None: -1,
torch.arccosh: lambda input, out=None: -1,
torch.add: lambda input, other, out=None: -1,
torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,
torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1,
torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1,
torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1,
torch.affine_grid_generator: lambda theta, size, align_corners: -1,
torch.all: lambda input, dim=None: -1,
torch.allclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
torch.alpha_dropout: lambda input, p, train, inplace=False: -1,
torch.amax: lambda input, dim=None: -1,
torch.amin: lambda input, dim=None: -1,
torch.angle: lambda input, out=None: -1,
torch.any: lambda input, dim=None, keepdim=False, out=None: -1,
torch.argmax: lambda input: -1,
torch.argmin: lambda input: -1,
torch.argsort: lambda input, dim=None: -1,
torch.asin: lambda input, out=None: -1,
torch._assert_async: lambda input: -1,
torch.arcsin: lambda input, out=None: -1,
torch.asinh: lambda input, out=None: -1,
torch.arcsinh: lambda input, out=None: -1,
torch.atan: lambda input, out=None: -1,
torch.arctan: lambda input, out=None: -1,
torch.atan2: lambda input, other, out=None: -1,
torch.atanh: lambda input, out=None: -1,
torch.arctanh: lambda input, out=None: -1,
torch.atleast_1d: lambda *tensors: -1,
torch.atleast_2d: lambda *tensors: -1,
torch.atleast_3d: lambda *tensors: -1,
torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1,
torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1,
torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1,
torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1,
torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1,
torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
torch.batch_norm_stats: lambda input, eps: -1,
torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
torch.bernoulli: lambda input, generator=None, out=None: -1,
torch.bilinear: lambda input1, input2, weight, bias: -1,
torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
reduction='mean', pos_weight=None: -1),
torch.bincount: lambda input, weights=None, minlength=0: -1,
torch.binomial: lambda count, prob, generator=None: -1,
torch.bitwise_and: lambda input, other, out=None: -1,
torch.bitwise_not: lambda input, out=None: -1,
torch.bitwise_or: lambda input, other, out=None: -1,
torch.bitwise_xor: lambda input, other, out=None: -1,
torch.bitwise_left_shift: lambda input, other, out=None: -1,
torch.bitwise_right_shift: lambda input, other, out=None: -1,
torch.block_diag: lambda *tensors: -1,
torch.bmm: lambda input, mat2, out=None: -1,
torch.broadcast_tensors: lambda *tensors: -1,
torch.broadcast_to: lambda self, size: -1,
torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1,
torch.cartesian_prod: lambda *tensors: -1,
torch.cat: lambda tensors, dim=0, out=None: -1,
torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
torch.ceil: lambda input, out=None: -1,
torch.celu: lambda input, alpha=1., inplace=False: -1,
torch.chain_matmul: lambda *matrices, out=None: -1,
torch.channel_shuffle: lambda input, groups: -1,
torch.cholesky: lambda input, upper=False, out=None: -1,
torch.linalg.cholesky: lambda input, out=None: -1,
torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
torch.cholesky_inverse: lambda input, upper=False, out=None: -1,
torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1,
torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1,
torch.chunk: lambda input, chunks, dim=0: -1,
torch.clamp: lambda input, min=None, max=None, out=None: -1,
torch.clip: lambda input, min=None, max=None, out=None: -1,
torch.clamp_min: lambda input, min, out=None: -1,
torch.clamp_max: lambda input, max, out=None: -1,
torch.column_stack: lambda tensors, out=None: -1,
torch.clone: lambda input: -1,
torch.combinations: lambda input, r=2, with_replacement=False: -1,
torch.complex: lambda real, imag: -1,
torch.copysign: lambda input, other, out=None: -1,
torch.polar: lambda abs, ang: -1,
torch.linalg.cond: lambda input, ord=None: -1,
torch.conj: lambda input, out=None: -1,
torch.conj_physical: lambda input, out=None: -1,
torch.resolve_conj: lambda input, out=None: -1,
torch.constant_pad_nd: lambda input, pad, value=0: -1,
torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_padding, groups: -1,
torch.conv_tbc: lambda input, weight, bias, pad=0: -1,
torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
torch.cos: lambda input, out=None: -1,
torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
torch.cosh: lambda input, out=None: -1,
torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
torch.count_nonzero: lambda input: -1,
torch.cross: lambda input, other, dim=-1, out=None: -1,
torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
zero_infinity=False: -1),
torch.cummax: lambda input, dim, out=None: -1,
torch.cummin: lambda input, dim, out=None: -1,
torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
torch.logcumsumexp: lambda input, dim, out=None: -1,
torch.deg2rad: lambda input, out=None: -1,
torch.dequantize: lambda input: -1,
torch.det: lambda input: -1,
torch.linalg.det: lambda input: -1, # alias for torch.det # type: ignore[attr-defined]
torch.detach: lambda input: -1,
torch.diag: lambda input, diagonal=0, out=None: -1,
torch.diag_embed: lambda input, diagonal=0, out=None: -1,
torch.diagflat: lambda input, offset=0: -1,
torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
torch.digamma: lambda input, out=None: -1,
torch.dist: lambda input, other, p=2: -1,
torch.div: lambda input, other, rounding_mode=None, out=None: -1,
torch.divide: lambda input, other, rounding_mode=None, out=None: -1,
torch.dot: lambda input, other, out=None: -1,
torch.dropout: lambda input, p, train, inplace=False: -1,
torch.dsmm: lambda input, mat2: -1,
torch.hsmm: lambda mat1, mat2: -1,
torch.dsplit: lambda input, indices_or_sections: -1,
torch.dstack: lambda tensors, out=None: -1,
torch.eig: lambda input, eigenvectors=False, out=None: -1,
torch.linalg.eig: lambda input, out=None: -1,
torch.linalg.eigvals: lambda input, out=None: -1,
torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
torch.einsum: lambda equation, *operands: -1,
torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
sparse=False: -1),
torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.eq: lambda input, other, out=None: -1,
torch.equal: lambda input, other: -1,
torch.erf: lambda input, out=None: -1,
torch.erfc: lambda input, out=None: -1,
torch.erfinv: lambda input, out=None: -1,
torch.exp: lambda input, out=None: -1,
torch.exp2: lambda input, out=None: -1,
torch.expm1: lambda input, out=None: -1,
torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
weight_zero_point, bias: -1),
torch.fbgemm_linear_quantize_weight: lambda input: -1,
torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
torch.feature_alpha_dropout: lambda input, p, train: -1,
torch.feature_dropout: lambda input, p, train: -1,
torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1,
torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1,
torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1,
torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1,
torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1,
torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1,
torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1,
torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1,
torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
torch.fft.fftshift: lambda input, dim=None: -1,
torch.fft.ifftshift: lambda input, dim=None: -1,
torch.fix: lambda input, out=None: -1,
torch.flatten: lambda input, start_dim=0, end_dim=-1: -1,
torch.flip: lambda input, dims: -1,
torch.fliplr: lambda input: -1,
torch.flipud: lambda input: -1,
torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
torch.floor: lambda input, out=None: -1,
torch.floor_divide: lambda input, other: -1,
torch.float_power: lambda input, exponent, out=None: -1,
torch.fmod: lambda input, other, out=None: -1,
torch.frac: lambda input, out=None: -1,
torch.frexp: lambda input, out=None: -1,
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
torch.gcd: lambda input, other, out=None: -1,
torch.ge: lambda input, other, out=None: -1,
torch.greater_equal: lambda input, other, out=None: -1,
torch.geqrf: lambda input, out=None: -1,
torch.i0: lambda input, out=None: -1,
torch.inner: lambda input, other, out=None: -1,
torch.outer: lambda input, vec2, out=None: -1,
torch.ger: lambda input, vec2, out=None: -1, # alias for torch.outer
torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: -1,
torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
torch.gru: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.gt: lambda input, other, out=None: -1,
torch.greater: lambda input, other, out=None: -1,
torch.hardshrink: lambda input, lambd=0.5: -1,
torch.heaviside: lambda input, values, out=None: -1,
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
torch.linalg.householder_product: lambda input, tau: -1,
torch.hspmm: lambda mat1, mat2, out=None: -1,
torch.hsplit: lambda input, indices_or_sections: -1,
torch.hstack: lambda tensors, out=None: -1,
torch.hypot: lambda input, other, out=None: -1,
torch.igamma: lambda input, other, out=None: -1,
torch.igammac: lambda input, other, out=None: -1,
torch.imag: lambda input, out=None: -1,
torch.index_add: lambda input, dim, index, source: -1,
torch.index_copy: lambda input, dim, index, source: -1,
torch.index_put: lambda input, indices, values, accumulate=False: -1,
torch.index_select: lambda input, dim, index, out=None: -1,
torch.index_fill: lambda input, dim, index, value: -1,
torch.isfinite: lambda tensor: -1,
torch.isin: lambda elements, test_elements, assume_unique=False, invert=False: -1,
torch.isinf: lambda tensor: -1,
torch.isreal: lambda tensor: -1,
torch.isposinf: lambda input, out=None: -1,
torch.isneginf: lambda input, out=None: -1,
torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
cudnn_enabled: -1),
torch.int_repr: lambda input: -1,
torch.inverse: lambda input, out=None: -1,
torch.linalg.inv: lambda input, out=None: -1,
torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
torch.is_complex: lambda input: -1,
torch.is_conj: lambda input: -1,
torch.is_distributed: lambda input: -1,
torch.is_inference: lambda input: -1,
torch.is_floating_point: lambda input: -1,
torch.is_nonzero: lambda input: -1,
torch.is_same_size: lambda input, other: -1,
torch.is_signed: lambda input: -1,
torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
torch.isnan: lambda input: -1,
torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
normalized=False, onesided=None, length=None, return_complex=False: -1),
torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
torch.kron: lambda input, other: -1,
torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
torch.lcm: lambda input, other, out=None: -1,
torch.ldexp: lambda input, other, out=None: -1,
torch.le: lambda input, other, out=None: -1,
torch.less_equal: lambda input, other, out=None: -1,
torch.lerp: lambda input, end, weight, out=None: -1,
torch.lgamma: lambda input, out=None: -1,
torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
torch.log: lambda input, out=None: -1,
torch.log_softmax: lambda input, dim, dtype=None: -1,
torch.log10: lambda input, out=None: -1,
torch.log1p: lambda input, out=None: -1,
torch.log2: lambda input, out=None: -1,
torch.logaddexp: lambda input, other, out=None: -1,
torch.logaddexp2: lambda input, other, out=None: -1,
torch.logdet: lambda input: -1,
torch.xlogy: lambda x, y: -1,
torch.logical_and: lambda input, other, out=None: -1,
torch.logical_not: lambda input, out=None: -1,
torch.logical_or: lambda input, other, out=None: -1,
torch.logical_xor: lambda input, other, out=None: -1,
torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
torch.logit: lambda input, eps=None: -1,
torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1,
torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.lstsq: lambda input, A, out=None: -1,
torch.lt: lambda input, other, out=None: -1,
torch.less: lambda input, other, out=None: -1,
torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, # type: ignore[attr-defined] # noqa: B950
torch.masked_fill: lambda input, mask, value: -1,
torch.masked_scatter: lambda input, mask, source: -1,
torch.masked_select: lambda input, mask, out=None: -1,
torch.matmul: lambda input, other, out=None: -1,
torch.matrix_power: lambda input, n: -1,
torch.linalg.matrix_power: lambda input, n, out=None: -1,
torch.matrix_rank: lambda input, tol=None, symmetric=False: -1,
torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
torch.linalg.multi_dot: lambda tensors, out=None: -1,
torch.matrix_exp: lambda input: -1,
torch.max: lambda input, out=None: -1,
torch.maximum: lambda input, other, out=None: -1,
torch.fmax: lambda input, other, out=None: -1,
torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.mean: lambda input, dim=None: -1,
torch.median: lambda input, dim=None: -1,
torch.nanmedian: lambda input, dim=None: -1,
torch.meshgrid: lambda *tensors, **kwargs: -1,
torch.min: lambda input, out=None: -1,
torch.minimum: lambda input, other, out=None: -1,
torch.fmin: lambda input, other, out=None: -1,
torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
exponential_average_factor, epsilon: -1),
torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
groups, benchmark, deterministic: -1),
torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
deterministic: -1),
torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
dropout, train, bidirectional, batch_sizes, dropout_state: -1),
torch.mm: lambda input, mat2, out=None: -1,
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
torch.movedim: lambda input, source, destination: -1,
torch.moveaxis: lambda input, source, destination: -1,
torch.msort: lambda input, out=None: -1,
torch.mul: lambda input, other, out=None: -1,
torch.multiply: lambda input, other, out=None: -1,
torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
torch.mv: lambda input, vec, out=None: -1,
torch.mvlgamma: lambda input, p: -1,
torch.narrow: lambda input, dim, start, length: -1,
torch.narrow_copy: lambda input, dim, start, length: -1,
torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
torch.ne: lambda input, other, out=None: -1,
torch.not_equal: lambda input, other, out=None: -1,
torch.neg: lambda input, out=None: -1,
torch.negative: lambda input, out=None: -1,
torch.nextafter: lambda input, other, out=None: -1,
torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1,
torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1,
torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1,
torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1,
torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1,
torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1,
torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1,
torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
count_include_pad=True, divisor_override=None: -1),
torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
count_include_pad=True, divisor_override=None: -1),
torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
momentum=0.1, eps=1e-05: -1),
torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
reduction="mean": -1),
torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
reduce=None, reduction="mean", pos_weight=None: -1),
torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction="mean": -1),
torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
reduction='mean', zero_infinity=False: -1),
torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
scale_grad_by_freq=False, sparse=False: -1),
torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
include_last_offset=False, padding_idx=None: -1),
torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
return_indices=False, _random_samples=None: -1),
torch.nn.functional.fractional_max_pool2d_with_indices: (
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None: -1),
torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
return_indices=False, _random_samples=None: -1),
torch.nn.functional.fractional_max_pool3d_with_indices: (
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
_random_samples=None: -1),
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
torch.nn.functional.gelu: lambda input: -1,
torch.nn.functional.glu: lambda input, dim=-1: -1,
torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
reduction='mean': -1),
torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
use_input_stats=True, momentum=0.1, eps=1e-05: -1),
torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
recompute_scale_factor=None: -1),
torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
torch.nn.functional.linear: lambda input, weight, bias=None: -1,
torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1,
torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.logsigmoid: lambda input: -1,
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False: -1),
torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
ceil_mode=False, return_indices=False: -1),
torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False: -1),
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
torch.nn.functional.multi_head_attention_forward: (
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
v_proj_weight=None, static_k=None, static_v=None: -1),
torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
reduction='mean': -1),
torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
reduce=None, reduction='mean': -1),
torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean': -1),
torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
eps=1e-08, reduce=None, reduction='mean': -1),
torch.nn.functional.prelu: lambda input, weight: -1,
torch.nn.functional.relu: lambda input, inplace=False: -1,
torch.nn.functional.relu6: lambda input, inplace=False: -1,
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
torch.nn.functional.selu: lambda input, inplace=False: -1,
torch.nn.functional.silu: lambda input, inplace=False: -1,
torch.nn.functional.mish: lambda input, inplace=False: -1,
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
torch.nn.functional.softshrink: lambda input, lambd=0.5: -1,
torch.nn.functional.softsign: lambda input: -1,
torch.nn.functional.tanhshrink: lambda input: -1,
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
swap=False, size_average=None, reduce=None, reduction='mean': -1),
torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
distance_function=None, margin=1.0,
swap=False, reduction='mean': -1),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nonzero: lambda input, as_tuple=False: -1,
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
torch.numel: lambda input: -1,
torch.orgqr: lambda input, tau: -1,
torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
torch.permute: lambda self, dim: -1,
torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1,
torch.pdist: lambda input, p=2: -1,
torch.pinverse: lambda input, rcond=1e-15: -1,
torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1,
torch.pixel_shuffle: lambda input, upscale_factor: -1,
torch.pixel_unshuffle: lambda input, downscale_factor: -1,
torch.poisson: lambda input, generator=None: -1,
torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
torch.polygamma: lambda input, n, out=None: -1,
torch.positive: lambda input, out=None: -1,
torch.prelu: lambda input, weight: -1,
torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.pow: lambda input, exponent, out=None: -1,
torch.prod: lambda input, dtype=None: -1,
torch.put: lambda input, index, source, accumulate=False: -1,
torch.q_per_channel_axis: lambda input: -1,
torch.q_per_channel_scales: lambda input: -1,
torch.q_per_channel_zero_points: lambda input: -1,
torch.q_scale: lambda input: -1,
torch.q_zero_point: lambda input: -1,
torch.qr: lambda input, some=True, out=None: -1,
torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
torch.quantile: lambda input, q, dim=None, keepdim=False, out=None: -1,
torch.nanquantile: lambda input, q, dim=None, keepdim=False, out=None: -1,
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
dilation=(1,), ceil_mode=False: -1),
torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
dilation=(1, 1), ceil_mode=False: -1),
torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.rad2deg: lambda input, out=None: -1,
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.ravel: lambda input: -1,
torch.real: lambda input, out=None: -1,
torch.vdot: lambda input, other, out=None: -1,
torch.view_as_real: lambda input: -1,
torch._view_as_real_physical: lambda input: -1,
torch.view_as_complex: lambda input: -1,
torch.reciprocal: lambda input, out=None: -1,
torch.relu: lambda input, inplace=False: -1,
torch.remainder: lambda input, other, out=None: -1,
torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
torch.repeat_interleave: lambda input, repeats, dim=None: -1,
torch.reshape: lambda input, shape: -1,
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.roll: lambda input, shifts, dims=None: -1,
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
torch.round: lambda input, out=None: -1,
torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
torch.rsqrt: lambda input, out=None: -1,
torch.rsub: lambda input, other, alpha=1: -1,
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
torch.scatter: lambda input, dim, index, src: -1,
torch.scatter_add: lambda input, dim, index, src: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
torch.select: lambda input, dim, index: -1,
torch.selu: lambda input, inplace=False: -1,
torch.sigmoid: lambda input, out=None: -1,
torch.sign: lambda input, out=None: -1,
torch.signbit: lambda input, out=None: -1,
torch.sgn: lambda input, out=None: -1,
torch.sin: lambda input, out=None: -1,
torch.sinc: lambda input, out=None: -1,
torch.sinh: lambda input, out=None: -1,
torch.slogdet: lambda input: -1,
torch.linalg.slogdet: lambda input: -1,
torch.smm: lambda input, mat2: -1,
torch.spmm: lambda input, mat2: -1,
torch.softmax: lambda input, dim, dtype=None: -1,
torch.solve: lambda input, A, out=None: -1,
torch.linalg.solve: lambda input, other, out=None: -1,
torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
torch.sqrt: lambda input, out=None: -1,
torch.square: lambda input, out=None: -1,
torch.squeeze: lambda input, dim=None, out=None: -1,
torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
torch.stack: lambda tensors, dim=0, out=None: -1,
torch.std: lambda input, dim=None: -1,
torch.std_mean: lambda input, dim=None: -1,
torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
torch.sub: lambda input, other, out=None: -1,
torch.subtract: lambda input, other, out=None: -1,
torch.sum: lambda input, dim=None: -1,
torch.nansum: lambda input, dim=None: -1,
torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
torch.linalg.svd: lambda input, full_matrices=True, out=None: -1,
torch.linalg.svdvals: lambda input, out=None: -1,
torch.symeig: lambda input, eigenvectors=False, upper=True, out=None: -1,
torch.swapaxes: lambda input, dim0, dim1: -1,
torch.swapdims: lambda input, axis0, axis1: -1,
torch.special.entr: lambda input: -1,
torch.special.erf: lambda input: -1,
torch.special.erfc: lambda input: -1,
torch.special.erfcx: lambda input: -1,
torch.special.erfinv: lambda input: -1,
torch.special.exp2: lambda input: -1,
torch.special.expm1: lambda input: -1,
torch.special.expit: lambda input: -1,
torch.special.digamma: lambda input: -1,
torch.special.psi: lambda input: -1,
torch.special.gammaln: lambda input: -1,
torch.special.i0: lambda input: -1,
torch.special.i0e: lambda input: -1,
torch.special.i1: lambda input: -1,
torch.special.i1e: lambda input: -1,
torch.special.logit: lambda input: -1,
torch.special.log1p: lambda input: -1,
torch.special.round: lambda input: -1,
torch.special.sinc: lambda input: -1,
torch.special.ndtri: lambda input: -1,
torch.special.ndtr: lambda input: -1,
torch.special.xlog1py: lambda input, other, out=None: -1,
torch.special.zeta: lambda self, other, out=None: -1,
torch.t: lambda input: -1,
torch.take: lambda input, index: -1,
torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
torch.tan: lambda input, out=None: -1,
torch.tanh: lambda input, out=None: -1,
torch.linalg.tensorinv: lambda a, ind=2: -1,
torch.linalg.tensorsolve: lambda a, b, dims=None: -1,
torch.tensordot: lambda a, b, dims=2, out=None: -1,
torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
torch.threshold: lambda input, threshold, value, inplace=False: -1,
torch.tile: lambda input, dims: -1,
torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
torch.trace: lambda input: -1,
torch.transpose: lambda input, dim0, dim1: -1,
torch.trapz: lambda y, x=None, dim=-1: -1,
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
torch.tril: lambda input, diagonal=0, out=None: -1,
torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
size_average=None, reduce=None, reduction='mean': -1),
torch.triu: lambda input, diagonal=0, out=None: -1,
torch.true_divide: lambda input, other: -1,
torch.trunc: lambda input, out=None: -1,
torch.unbind: lambda input, dim=0: -1,
torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1,
torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
torch.unsqueeze: lambda input, dim, out=None: -1,
torch.var: lambda input, dim=None: -1,
torch.var_mean: lambda input, dim=None: -1,
torch.vsplit: lambda input, indices_or_sections: -1,
torch.vstack: lambda tensors, out=None: -1,
torch.where: lambda condition, x=None, y=None: -1,
torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
Tensor.__floordiv__: lambda self, other: -1,
Tensor.__rfloordiv__: lambda self, other: -1,
Tensor.__ifloordiv__: lambda self, other: -1,
Tensor.__truediv__: lambda self, other: -1,
Tensor.__rtruediv__: lambda self, other: -1,
Tensor.__itruediv__: lambda self, other: -1,
Tensor.__lshift__: lambda self, other: -1,
Tensor.__rlshift__: lambda self, other: -1,
Tensor.__ilshift__: lambda self, other: -1,
Tensor.__rshift__: lambda self, other: -1,
Tensor.__rrshift__: lambda self, other: -1,
Tensor.__irshift__: lambda self, other: -1,
Tensor.__float__: lambda self: -1,
Tensor.__complex__: lambda self: -1,
Tensor.__array__: lambda self, dtype: -1,
Tensor.__bool__: lambda self: -1,
Tensor.__contains__: lambda self, other: -1,
Tensor.__neg__: lambda self: -1,
Tensor.__invert__: lambda self: -1,
Tensor.__mod__: lambda self, other: -1,
Tensor.__rmod__: lambda self, other: -1,
Tensor.__imod__: lambda self, other: -1,
Tensor.__array_wrap__: lambda self, array: -1,
Tensor.__getitem__: lambda self, idx: -1,
Tensor.__deepcopy__: lambda self, memo: -1,
Tensor.__int__: lambda self: -1,
Tensor.__long__: lambda self: -1,
Tensor.__hash__: lambda self: -1,
Tensor.__index__: lambda self: -1,
Tensor.__len__: lambda self: -1,
Tensor.__format__: lambda self, format_spec: -1,
Tensor.__reduce_ex__: lambda self, proto: -1,
Tensor.__reversed__: lambda self: -1,
Tensor.__repr__: lambda self: -1,
Tensor.__setitem__: lambda self, k, v: -1,
Tensor.__setstate__: lambda self, d: -1,
Tensor.T.__get__: lambda self: -1,
Tensor._backward_hooks.__get__: lambda self: -1,
Tensor._base.__get__: lambda self: -1,
Tensor._cdata.__get__: lambda self: -1,
Tensor.grad.__get__: lambda self: -1,
Tensor._grad.__get__: lambda self: -1,
Tensor._grad_fn.__get__: lambda self: -1,
Tensor.grad_fn.__get__: lambda self: -1,
Tensor._version.__get__: lambda self: -1,
Tensor.data.__get__: lambda self: -1,
Tensor.device.__get__: lambda self: -1,
Tensor.dtype.__get__: lambda self: -1,
Tensor.is_cuda.__get__: lambda self: -1,
Tensor.is_xpu.__get__: lambda self: -1,
Tensor.is_leaf.__get__: lambda self: -1,
Tensor.retains_grad.__get__: lambda self: -1,
Tensor.is_meta.__get__: lambda self: -1,
Tensor.is_mlc.__get__: lambda self: -1,
Tensor.is_mkldnn.__get__: lambda self: -1,
Tensor.is_quantized.__get__: lambda self: -1,
Tensor.is_sparse.__get__: lambda self: -1,
Tensor.is_sparse_csr.__get__: lambda self: -1,
Tensor.is_vulkan.__get__: lambda self: -1,
Tensor.layout.__get__: lambda self: -1,
Tensor.name.__get__: lambda self: -1,
Tensor.names.__get__: lambda self: -1,
Tensor.ndim.__get__: lambda self: -1,
Tensor.output_nr.__get__: lambda self: -1,
Tensor.requires_grad.__get__: lambda self: -1,
Tensor.shape.__get__: lambda self: -1,
Tensor.volatile.__get__: lambda self: -1,
Tensor.real.__get__: lambda self: -1,
Tensor.imag.__get__: lambda self: -1,
Tensor.__cuda_array_interface__.__get__: lambda self: -1,
Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1,
Tensor._dimI: lambda self: -1,
Tensor._dimV: lambda self: -1,
Tensor._indices: lambda self: -1,
Tensor._is_view: lambda self: -1,
Tensor._nnz: lambda self: -1,
Tensor.crow_indices: lambda self: -1,
Tensor.col_indices: lambda self: -1,
Tensor._update_names: lambda self, names, inplace: -1,
Tensor._values: lambda self: -1,
Tensor.align_as: lambda self, other: -1,
Tensor.align_to: lambda self, order, ellipsis_idx: -1,
Tensor.apply_: lambda self, callable: -1,
Tensor.as_strided: lambda self, size, stride: -1,
Tensor.as_strided_: lambda self, size, stride: -1,
Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1,
Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1,
Tensor.bool: lambda self, memory_format=torch.preserve_format: -1,
Tensor.byte: lambda self, memory_format=torch.preserve_format: -1,
Tensor.char: lambda self, memory_format=torch.preserve_format: -1,
Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1,
Tensor.coalesce: lambda self: -1,
Tensor._coalesced_: lambda self, coalesced: -1,
Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1,
Tensor.copy_: lambda self, src, non_blocking=False: -1,
Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1,
Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1,
Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
Tensor.data_ptr: lambda self: -1,
Tensor.dense_dim: lambda self: -1,
Tensor.dim: lambda self: -1,
Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
Tensor.element_size: lambda self: -1,
Tensor.expand: lambda self, size: -1,
Tensor.expand_as: lambda self, other: -1,
Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1,
Tensor.fill_: lambda self, value: -1,
Tensor.fill_diagonal_: lambda self, value: -1,
Tensor.float: lambda self, memory_format=torch.preserve_format: -1,
Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1,
Tensor.geometric_: lambda self, p, *, generator=None: -1,
Tensor.get_device: lambda self: -1,
Tensor.half: lambda self, memory_format=torch.preserve_format: -1,
Tensor.has_names: lambda self: -1,
Tensor.indices: lambda self: -1,
Tensor.int: lambda self, memory_format=torch.preserve_format: -1,
Tensor.is_coalesced: lambda self: -1,
Tensor.is_contiguous: lambda self: -1,
Tensor.is_inference: lambda self: -1,
Tensor.is_pinned: lambda self: -1,
Tensor.is_set_to: lambda self, tensor: -1,
Tensor.is_shared: lambda self: -1,
Tensor.item: lambda self: -1,
Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1,
Tensor.log_softmax: lambda self, dim: -1,
Tensor.long: lambda self, memory_format=torch.preserve_format: -1,
Tensor.map_: lambda self, tensor, callable: -1,
Tensor.map2_: lambda self, x, y, callable: -1,
Tensor.mm: lambda self, mat2: -1,
Tensor.narrow_copy: lambda self, dimension, start, length: -1,
Tensor.ndimension: lambda self: -1,
Tensor.nelement: lambda self: -1,
Tensor.normal_: lambda self: -1,
Tensor.numpy: lambda self: -1,
Tensor.permute: lambda self, dim: -1,
Tensor.pin_memory: lambda self: -1,
Tensor.put_: lambda self, indices, tensor, accumulate=False: -1,
Tensor.qscheme: lambda self: -1,
Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1,
Tensor.record_stream: lambda self, stream: -1,
Tensor.refine_names: lambda self, names: -1,
Tensor.register_hook: lambda self, hook: -1,
Tensor.rename: lambda self, name: -1,
Tensor.repeat: lambda self, *size: -1,
Tensor.requires_grad_: lambda self, requires_grad=True: -1,
Tensor.reshape_as: lambda self, other: -1,
Tensor.resize: lambda self, *size: -1,
Tensor.resize_: lambda self, size: -1,
Tensor.resize_as: lambda self, other: -1,
Tensor.retain_grad: lambda self: -1,
Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
Tensor.share_memory_: lambda self: -1,
Tensor.short: lambda self, memory_format=torch.preserve_format: -1,
Tensor.size: lambda self: -1,
Tensor.sparse_dim: lambda self: -1,
Tensor.sparse_mask: lambda self, mask: -1,
Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
Tensor.storage: lambda self: -1,
Tensor.storage_offset: lambda self: -1,
Tensor.storage_type: lambda self: -1,
Tensor.sum_to_size: lambda self, size: -1,
Tensor.tile: lambda self, *reps: -1,
Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
Tensor.to_dense: lambda self: -1,
Tensor.to_sparse: lambda self: -1,
Tensor.tolist: lambda self: -1,
Tensor.to_mkldnn: lambda self: -1,
Tensor.type_as: lambda self, other: -1,
Tensor.unfold: lambda self, dimension, size, step: -1,
Tensor.uniform_: lambda self, from_=0, to=1: -1,
Tensor.values: lambda self: -1,
Tensor.view: lambda self, shape: -1,
Tensor.view_as: lambda self, other: -1,
Tensor.zero_: lambda self: -1,
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
}
ret2 = {}
ignored = get_ignored_functions()
for k, v in ret.items():
# Generate methods like __add__ and add_ by default from add
names = [
k.__name__, # Default method
k.__name__ + "_", # Inplace variant
"__" + k.__name__ + "__", # Dunder method
"__i" + k.__name__ + "__", # Inplace dunder method
"__r" + k.__name__ + "__", # Reverse dunder method
]
if k.__name__.startswith("bitwise_"):
            # bitwise_<op> ops have dunder methods of the form __<op>__,
            # plus the in-place (__i<op>__) and reverse (__r<op>__) variants.
subname = k.__name__[len("bitwise_"):]
names.extend([
"__" + subname + "__",
"__i" + subname + "__",
"__r" + subname + "__"
])
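            # For example, assuming torch.bitwise_and is among the keys above,
            # it also yields __and__, __iand__ and __rand__ here.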
for name in names:
func = getattr(Tensor, name, None)
if callable(func) and func not in ret and func not in ignored:
ret2[func] = v
ret.update(ret2)
return ret
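
# As a rough illustration (a sketch, not part of the public API): each lambda
# in the table above is a signature-compatible stand-in for the function it is
# keyed by, so a test can call it with the same arguments, e.g.
#
#     overrides = get_testing_overrides()
#     stub = overrides[torch.unsqueeze]
#     assert stub(torch.randn(3), 0) == -1   # mirrors torch.unsqueeze(input, dim)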

def wrap_torch_function(dispatcher: Callable):
    """Wraps a given function with ``__torch_function__``-related functionality.

Parameters
----------
dispatcher: Callable
A callable that returns an iterable of Tensor-likes passed into the function.

    Note
    ----
    This decorator may reduce the performance of your code. Generally, it's
    enough to express your code as a series of functions that themselves
    support ``__torch_function__``. If you find yourself in the rare situation
    where this is not the case, e.g. if you're wrapping a low-level library
    and you also need it to work for Tensor-likes, then this function is
    available.

Examples
--------
    >>> def dispatcher(a): # Must have the same signature as func
    ...     return (a,)
    >>> @torch.overrides.wrap_torch_function(dispatcher)
    ... def func(a): # This will make func dispatchable by __torch_function__
    ...     return a + 0
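
    Passing a Tensor-like into the wrapped function then routes the call
    through its ``__torch_function__`` handler (``MyTensorLike`` below is a
    hypothetical example class, not part of torch):

    >>> class MyTensorLike:
    ...     def __torch_function__(self, func, types, args=(), kwargs=None):
    ...         return "intercepted"
    >>> func(MyTensorLike())
    'intercepted'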
"""
def inner(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
if has_torch_function(relevant_args):
return handle_torch_function(func, relevant_args, *args, **kwargs)
return func(*args, **kwargs)
return wrapped
return inner

def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]:
    """Returns a list of arguments on which to call __torch_function__.

Checks arguments in relevant_args for __torch_function__ implementations,
storing references to the arguments and their types in overloaded_args and
overloaded_types in order of calling precedence. Only distinct types are
considered. If a type is a subclass of another type it will have higher
precedence, otherwise the precedence order is the same as the order of
arguments in relevant_args, that is, from left-to-right in the argument list.
The precedence-determining algorithm implemented in this function is
described in `NEP-0018`_.
See torch::append_overloaded_arg for the equivalent function in the C++
implementation.
Parameters
----------
relevant_args : iterable of array-like
Iterable of array-like arguments to check for __torch_function__
methods.
Returns
-------
overloaded_args : list
Arguments from relevant_args on which to call __torch_function__
methods, in the order in which they should be called.
.. _NEP-0018:
https://numpy.org/neps/nep-0018-array-function-protocol.html
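
    Examples
    --------
    Subclasses take precedence over their parents even when they appear
    later in the argument list (``SubTensor`` here is a hypothetical
    ``torch.Tensor`` subclass):

    >>> class SubTensor(torch.Tensor): ...
    >>> t, s = torch.tensor([1.]), SubTensor([1.])
    >>> [type(a).__name__ for a in _get_overloaded_args([t, s])]
    ['SubTensor', 'Tensor']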
"""
# Runtime is O(num_arguments * num_unique_types)
overloaded_types: Set[Type] = set()
overloaded_args: List[Any] = []
for arg in relevant_args:
arg_type = type(arg)
# We only collect arguments if they have a unique type, which ensures
# reasonable performance even with a long list of possibly overloaded
# arguments.
if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__')):
            # Create lists explicitly for the first type (usually the only one
            # encountered) to avoid setting up the iterator for overloaded_args.
if overloaded_types:
overloaded_types.add(arg_type)
                # By default, insert argument at the end, but if it is
                # a subclass of another argument, insert it before that argument.
                # This ensures "subclasses before superclasses".
index = len(overloaded_args)
for i, old_arg in enumerate(overloaded_args):
if issubclass(arg_type, type(old_arg)):
index = i
break
overloaded_args.insert(index, arg)
else:
overloaded_types = {arg_type}
overloaded_args = [arg]
return overloaded_args

def handle_torch_function(
        public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
    """Implement a function with checks for ``__torch_function__`` overrides.

See torch::autograd::handle_torch_function for the equivalent of this
function in the C++ implementation.
Arguments
---------
public_api : function
Function exposed by the public torch API originally called like
``public_api(*args, **kwargs)`` on which arguments are now being
checked.
relevant_args : iterable
Iterable of arguments to check for __torch_function__ methods.
args : tuple
Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
Arbitrary keyword arguments originally passed into ``public_api``.
Returns
-------
object
        Result from calling ``public_api`` or an ``__torch_function__``
        method, as appropriate.
Raises
------
TypeError : if no implementation is found.
Example
-------
>>> def func(a):
... if type(a) is not torch.Tensor: # This will make func dispatchable by __torch_function__
... return handle_torch_function(func, (a,), a)
... return a + 0
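
    A handler may also return ``NotImplemented`` to decline; when every
    overloaded argument declines, a ``TypeError`` is raised. A sketch with a
    hypothetical class (the error message is abbreviated):

    >>> class Opaque:
    ...     def __torch_function__(self, func, types, args=(), kwargs=None):
    ...         return NotImplemented
    >>> handle_torch_function(torch.add, (Opaque(),), Opaque(), 1)
    Traceback (most recent call last):
        ...
    TypeError: no implementation found for 'torch.add' ...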
"""
# Check for __torch_function__ methods.
overloaded_args = _get_overloaded_args(relevant_args)
# overloaded_args already have unique types.
types = tuple(map(type, overloaded_args))
# Call overrides
for overloaded_arg in overloaded_args:
# Use `public_api` instead of `implementation` so __torch_function__
# implementations can do equality/identity comparisons.
result = overloaded_arg.__torch_function__(public_api, types, args, kwargs)
if result is not NotImplemented:
return result
func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
raise TypeError("no implementation found for '{}' on types that implement "
'__torch_function__: {}'
.format(func_name, [type(arg) for arg in overloaded_args]))

has_torch_function = _add_docstr(
    _has_torch_function,
    r"""Check for __torch_function__ implementations in the elements of an iterable.

    Considers exact ``Tensor`` and ``Parameter`` instances non-dispatchable.

Arguments
---------
relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.
Returns
-------
bool
True if any of the elements of relevant_args have __torch_function__
implementations, False otherwise.
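
    Examples
    --------
    A plain ``Tensor`` is non-dispatchable, while a ``Tensor`` subclass
    (which inherits the default ``__torch_function__``) is:

    >>> has_torch_function((torch.tensor([1.]),))
    False
    >>> class SubTensor(torch.Tensor): ...
    >>> has_torch_function((SubTensor([1.]),))
    True
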
    See Also
    --------
torch.is_tensor_like
Checks if something is a Tensor-like, including an exact ``Tensor``.
"""
)

has_torch_function_unary = _add_docstr(
    _has_torch_function_unary,
    r"""Special case of `has_torch_function` for single inputs.

Instead of:
`has_torch_function((t,))`
call:
`has_torch_function_unary(t)`
which skips unnecessary packing and unpacking work.
"""
)

has_torch_function_variadic = _add_docstr(
    _has_torch_function_variadic,
    r"""Special case of `has_torch_function` that skips tuple creation.

    This uses the METH_FASTCALL protocol introduced in Python 3.7; on Python
    3.6 and earlier its performance is roughly equivalent to
    `has_torch_function`.
Instead of:
`has_torch_function((a, b))`
call:
`has_torch_function_variadic(a, b)`
which skips unnecessary packing and unpacking work.
"""
)
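
# A typical call-site guard (a sketch; ``my_op`` is a hypothetical function,
# not part of torch) looks like:
#
#     def my_op(input, other):
#         if has_torch_function_variadic(input, other):
#             return handle_torch_function(my_op, (input, other), input, other)
#         return input + other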

@functools.lru_cache(None)
def get_overridable_functions() -> Dict[Any, List[Callable]]:
    """List functions that are overridable via __torch_function__

Returns
-------
Dict[Any, List[Callable]]
A dictionary that maps namespaces that contain overridable functions
to functions in that namespace that can be overridden.
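
    Examples
    --------
    >>> funcs = get_overridable_functions()
    >>> torch.add in funcs[torch]
    True
    >>> torch.Tensor.add in funcs[torch.Tensor]
    True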
"""
overridable_funcs = collections.defaultdict(list)
tested_namespaces = [
(torch, torch.__all__ + dir(torch._C._VariableFunctions)),
(torch.functional, torch.functional.__all__),
(torch.nn.functional, dir(torch.nn.functional)),
(torch.Tensor, dir(torch.Tensor)),
(torch.linalg, dir(torch.linalg)),
(torch.fft, dir(torch.fft)),
(torch.special, dir(torch.special)),
]
for namespace, ns_funcs in tested_namespaces:
for func_name in ns_funcs:
# ignore private functions or functions that are deleted in torch.__init__
if namespace is not torch.Tensor:
if func_name.startswith('_'):
continue
elif func_name.endswith('_'):
continue
elif not func_name[0].islower():
continue
elif func_name == 'unique_dim':
continue
            else:
                if func_name == '__weakref__':
                    continue
            func = getattr(namespace, func_name)
            if namespace is torch.Tensor and getattr(object, func_name, None) == func:
                continue
# ignore re-exported modules
if isinstance(func, types.ModuleType):
continue
# ignore __future__ imports
if isinstance(func, __future__._Feature):
continue
if not callable(func) and hasattr(func, "__get__"):
if func.__get__ in get_ignored_functions():
                    msg = ("{}.{} is in the tuple returned by torch.overrides.get_ignored_functions "
                           "but still has an explicit override")
assert func.__get__ not in get_testing_overrides(), msg.format(namespace, func.__name__)
continue
else:
overridable_funcs[func].append(func.__get__)
continue
if not callable(func):
continue
            # cannot be overridden by __torch_function__
if func in get_ignored_functions():
                msg = ("{}.{} is in the tuple returned by torch.overrides.get_ignored_functions "
                       "but still has an explicit override")
assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
continue
overridable_funcs[namespace].append(func)
return overridable_funcs

@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
    """Returns a set of the overridable methods on ``torch.Tensor``"""
overridable_funcs = get_overridable_functions()
methods = set(overridable_funcs[torch.Tensor])
return methods

def is_tensor_method_or_property(func: Callable) -> bool:
    """
Returns True if the function passed in is a handler for a
method or property belonging to ``torch.Tensor``, as passed
into ``__torch_function__``.
    .. note::
       For properties, their ``__get__`` method must be passed in.

    This function may be needed, in particular, for the following reasons:

    1. Methods/properties sometimes don't contain a `__module__` slot.
    2. They require that the first passed-in argument is an instance
       of ``torch.Tensor``.

Examples
--------
>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
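
    For a property, pass in its ``__get__`` method:

    >>> is_tensor_method_or_property(torch.Tensor.shape.__get__)
    True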
"""
return func in _get_tensor_methods() or func.__name__ == "__get__"

def is_tensor_like(inp):
    """
    Returns ``True`` if the passed-in input is a Tensor-like.

    Currently, this occurs whenever there's a ``__torch_function__``
    attribute on the type of the input.

Examples
--------
A subclass of tensor is generally a Tensor-like.
>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True
Built-in or user types aren't usually Tensor-like.
>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False
But, they can be made Tensor-like by implementing __torch_function__.
>>> class TensorLike:
... def __torch_function__(self, func, types, args, kwargs):
... return -1
>>> is_tensor_like(TensorLike())
True
"""
return type(inp) is torch.Tensor or hasattr(type(inp), "__torch_function__")