mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Alias for i0 to special namespace (#59141)
Summary: See https://github.com/pytorch/pytorch/issues/50345 cc: mruberry kshitij12345 Pull Request resolved: https://github.com/pytorch/pytorch/pull/59141 Reviewed By: ngimel Differential Revision: D28784097 Pulled By: mruberry fbshipit-source-id: 9b61a21906ef337292686fd40e328502a79e6f09
This commit is contained in:
parent
059a717c9e
commit
44c20ce676
|
|
@ -376,7 +376,6 @@ _(aten, hspmm) \
|
|||
_(aten, hsplit) \
|
||||
_(aten, hstack) \
|
||||
_(aten, hypot) \
|
||||
_(aten, i0) \
|
||||
_(aten, i0_) \
|
||||
_(aten, igamma) \
|
||||
_(aten, igamma_) \
|
||||
|
|
|
|||
|
|
@ -336,6 +336,8 @@ namespace c10 {
|
|||
_(aten, special_expm1) \
|
||||
_(aten, exp2) \
|
||||
_(aten, special_exp2) \
|
||||
_(aten, i0) \
|
||||
_(aten, special_i0) \
|
||||
_(aten, special_i0e) \
|
||||
_(aten, special_i1) \
|
||||
_(aten, special_i1e) \
|
||||
|
|
|
|||
|
|
@ -412,6 +412,10 @@ Tensor special_erfc(const Tensor& self) { return self.erfc(); }
|
|||
Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); }
|
||||
Tensor special_erfinv(const Tensor& self) { return self.erfinv(); }
|
||||
|
||||
// special_i0, alias for i0
|
||||
Tensor& special_i0_out(const Tensor& self, Tensor& result) { return at::i0_out(result, self); }
|
||||
Tensor special_i0(const Tensor& self) { return self.i0(); }
|
||||
|
||||
namespace {
|
||||
|
||||
inline Tensor calc_ndtr(const Tensor& self) {
|
||||
|
|
|
|||
|
|
@ -9405,6 +9405,14 @@
|
|||
dispatch:
|
||||
CompositeExplicitAutograd: special_xlog1py_out
|
||||
|
||||
- func: special_i0(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: special
|
||||
variants: function
|
||||
|
||||
- func: special_i0e(Tensor self) -> Tensor
|
||||
python_module: special
|
||||
variants: function
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ Functions
|
|||
.. autofunction:: expm1
|
||||
.. autofunction:: exp2
|
||||
.. autofunction:: gammaln
|
||||
.. autofunction:: i0
|
||||
.. autofunction:: i0e
|
||||
.. autofunction:: i1
|
||||
.. autofunction:: i1e
|
||||
|
|
|
|||
|
|
@ -3866,24 +3866,8 @@ add_docstr(torch.i0,
|
|||
r"""
|
||||
i0(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the zeroth order modified Bessel function of the first kind for each element of :attr:`input`.
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = I_0(\text{input}_{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2}
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.i0(torch.arange(5, dtype=torch.float32))
|
||||
tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019])
|
||||
|
||||
""".format(**common_args))
|
||||
Alias for :func:`torch.special.i0`.
|
||||
""")
|
||||
|
||||
add_docstr(torch.igamma,
|
||||
r"""
|
||||
|
|
|
|||
|
|
@ -182,6 +182,22 @@ inline Tensor& xlog1py_out(Tensor& result, const Tensor& self, const Scalar& oth
|
|||
return torch::special_xlog1py_out(result, self, other);
|
||||
}
|
||||
|
||||
/// Computes the zeroth order modified Bessel function of the first kind of input, elementwise
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.i0
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// auto t = torch::randn(128, dtype=kDouble);
|
||||
/// torch::special::i0(t);
|
||||
/// ```
|
||||
inline Tensor i0(const Tensor& self) {
|
||||
return torch::special_i0(self);
|
||||
}
|
||||
|
||||
inline Tensor& i0_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_i0_out(result, self);
|
||||
}
|
||||
|
||||
/// Computes the area under the standard Gaussian probability density function,
|
||||
/// integrated from minus infinity to :attr:`input`, elementwise
|
||||
/// See https://pytorch.org/docs/master/special.html#torch.special.ndtr
|
||||
|
|
@ -195,7 +211,7 @@ inline Tensor ndtr(const Tensor& self) {
|
|||
return torch::special_ndtr(self);
|
||||
}
|
||||
|
||||
inline Tensor ndtr_out(Tensor& result, const Tensor& self) {
|
||||
inline Tensor& ndtr_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_ndtr_out(result, self);
|
||||
}
|
||||
|
||||
|
|
@ -211,7 +227,7 @@ inline Tensor i0e(const Tensor& self) {
|
|||
return torch::special_i0e(self);
|
||||
}
|
||||
|
||||
inline Tensor i0e_out(Tensor& result, const Tensor& self) {
|
||||
inline Tensor& i0e_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_i0e_out(result, self);
|
||||
}
|
||||
|
||||
|
|
@ -227,7 +243,7 @@ inline Tensor i1(const Tensor& self) {
|
|||
return torch::special_i1(self);
|
||||
}
|
||||
|
||||
inline Tensor i1_out(Tensor& result, const Tensor& self) {
|
||||
inline Tensor& i1_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_i1_out(result, self);
|
||||
}
|
||||
|
||||
|
|
@ -243,7 +259,7 @@ inline Tensor i1e(const Tensor& self) {
|
|||
return torch::special_i1e(self);
|
||||
}
|
||||
|
||||
inline Tensor i1e_out(Tensor& result, const Tensor& self) {
|
||||
inline Tensor& i1e_out(Tensor& result, const Tensor& self) {
|
||||
return torch::special_i1e_out(result, self);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -117,6 +117,7 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
|
|||
{aten::special_exp2, aten::exp2},
|
||||
{aten::special_expm1, aten::expm1},
|
||||
{aten::special_logit, aten::logit},
|
||||
{aten::special_i0, aten::i0},
|
||||
{aten::orgqr, aten::linalg_householder_product},
|
||||
{aten::special_gammaln, aten::lgamma}};
|
||||
return alias_map;
|
||||
|
|
|
|||
|
|
@ -866,6 +866,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.special.expm1: lambda input: -1,
|
||||
torch.special.expit: lambda input: -1,
|
||||
torch.special.gammaln: lambda input: -1,
|
||||
torch.special.i0: lambda input: -1,
|
||||
torch.special.i0e: lambda input: -1,
|
||||
torch.special.i1: lambda input: -1,
|
||||
torch.special.i1e: lambda input: -1,
|
||||
|
|
|
|||
|
|
@ -272,6 +272,29 @@ Example::
|
|||
tensor([2.7726, 2.1972, 1.3863])
|
||||
""".format(**common_args))
|
||||
|
||||
i0 = _add_docstr(_special.special_i0,
|
||||
r"""
|
||||
i0(input, *, out=None) -> Tensor
|
||||
|
||||
Computes the zeroth order modified Bessel function of the first kind for each element of :attr:`input`.
|
||||
|
||||
.. math::
|
||||
\text{out}_{i} = I_0(\text{input}_{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2}
|
||||
|
||||
""" + r"""
|
||||
Args:
|
||||
input (Tensor): the input tensor
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> torch.i0(torch.arange(5, dtype=torch.float32))
|
||||
tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019])
|
||||
|
||||
""".format(**common_args))
|
||||
|
||||
i0e = _add_docstr(_special.special_i0e,
|
||||
r"""
|
||||
i0e(input, *, out=None) -> Tensor
|
||||
|
|
|
|||
|
|
@ -4899,6 +4899,7 @@ op_db: List[OpInfo] = [
|
|||
UnaryUfuncInfo('i0',
|
||||
ref=np_unary_ufunc_integer_promotion_wrapper(
|
||||
scipy.special.i0) if TEST_SCIPY else _NOTHING,
|
||||
aliases=('special.i0',),
|
||||
decorators=(precisionOverride({torch.bfloat16: 3e-1,
|
||||
torch.float16: 5e-1}),),
|
||||
backward_dtypesIfCPU=floating_types(),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user