mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[special] Alias for special.expm1 and special.exp2 (#54670)
Summary: Reference: https://github.com/pytorch/pytorch/issues/50345 Pull Request resolved: https://github.com/pytorch/pytorch/pull/54670 Reviewed By: H-Huang Differential Revision: D27401440 Pulled By: mruberry fbshipit-source-id: 02b1fd0e8ffd3f5a017d6b6b9229b76b92b4b745
This commit is contained in:
parent
75ed6fbd91
commit
c9d0c855f7
|
|
@ -320,7 +320,6 @@ _(aten, equal) \
|
||||||
_(aten, exp) \
|
_(aten, exp) \
|
||||||
_(aten, expand) \
|
_(aten, expand) \
|
||||||
_(aten, expand_as) \
|
_(aten, expand_as) \
|
||||||
_(aten, expm1) \
|
|
||||||
_(aten, exponential) \
|
_(aten, exponential) \
|
||||||
_(aten, eye) \
|
_(aten, eye) \
|
||||||
_(aten, feature_alpha_dropout) \
|
_(aten, feature_alpha_dropout) \
|
||||||
|
|
|
||||||
|
|
@ -319,6 +319,10 @@ namespace c10 {
|
||||||
_(aten, special_erfc) \
|
_(aten, special_erfc) \
|
||||||
_(aten, erfinv) \
|
_(aten, erfinv) \
|
||||||
_(aten, special_erfinv) \
|
_(aten, special_erfinv) \
|
||||||
|
_(aten, expm1) \
|
||||||
|
_(aten, special_expm1) \
|
||||||
|
_(aten, exp2) \
|
||||||
|
_(aten, special_exp2) \
|
||||||
_(aten, has_torch_function) \
|
_(aten, has_torch_function) \
|
||||||
FORALL_ATEN_BASE_SYMBOLS(_) \
|
FORALL_ATEN_BASE_SYMBOLS(_) \
|
||||||
_(onnx, Add) \
|
_(onnx, Add) \
|
||||||
|
|
|
||||||
|
|
@ -299,6 +299,14 @@ Tensor& expm1_out(const Tensor& self, Tensor& result) { return unary_op_impl_flo
|
||||||
Tensor expm1(const Tensor& self) { return unary_op_impl_float(self, expm1_stub); }
|
Tensor expm1(const Tensor& self) { return unary_op_impl_float(self, expm1_stub); }
|
||||||
Tensor& expm1_(Tensor& self) { return unary_op_impl_(self, at::expm1_out); }
|
Tensor& expm1_(Tensor& self) { return unary_op_impl_(self, at::expm1_out); }
|
||||||
|
|
||||||
|
// special_exp2, alias for exp2
|
||||||
|
Tensor& special_exp2_out(const Tensor& self, Tensor& result) { return at::exp2_out(result, self); }
|
||||||
|
Tensor special_exp2(const Tensor& self) { return self.exp2(); }
|
||||||
|
|
||||||
|
// special_expm1, alias for expm1
|
||||||
|
Tensor& special_expm1_out(const Tensor& self, Tensor& result) { return at::expm1_out(result, self); }
|
||||||
|
Tensor special_expm1(const Tensor& self) { return self.expm1(); }
|
||||||
|
|
||||||
Tensor& erf_out(const Tensor& self, Tensor& result) { return unary_op_impl_float_out(result, self, erf_stub); }
|
Tensor& erf_out(const Tensor& self, Tensor& result) { return unary_op_impl_float_out(result, self, erf_stub); }
|
||||||
Tensor erf(const Tensor& self) { return unary_op_impl_float(self, erf_stub); }
|
Tensor erf(const Tensor& self) { return unary_op_impl_float(self, erf_stub); }
|
||||||
Tensor& erf_(Tensor& self) { return unary_op_impl_(self, at::erf_out); }
|
Tensor& erf_(Tensor& self) { return unary_op_impl_(self, at::erf_out); }
|
||||||
|
|
@ -311,12 +319,15 @@ Tensor& erfinv_out(const Tensor& self, Tensor& result) { return unary_op_impl_fl
|
||||||
Tensor erfinv(const Tensor& self) { return unary_op_impl_float(self, erfinv_stub); }
|
Tensor erfinv(const Tensor& self) { return unary_op_impl_float(self, erfinv_stub); }
|
||||||
Tensor& erfinv_(Tensor& self) { return unary_op_impl_(self, at::erfinv_out); }
|
Tensor& erfinv_(Tensor& self) { return unary_op_impl_(self, at::erfinv_out); }
|
||||||
|
|
||||||
|
// special_erf, alias for erf
|
||||||
Tensor& special_erf_out(const Tensor& self, Tensor& result) { return at::erf_out(result, self); }
|
Tensor& special_erf_out(const Tensor& self, Tensor& result) { return at::erf_out(result, self); }
|
||||||
Tensor special_erf(const Tensor& self) { return self.erf(); }
|
Tensor special_erf(const Tensor& self) { return self.erf(); }
|
||||||
|
|
||||||
|
// special_erfc, alias for erfc
|
||||||
Tensor& special_erfc_out(const Tensor& self, Tensor& result) { return at::erfc_out(result, self); }
|
Tensor& special_erfc_out(const Tensor& self, Tensor& result) { return at::erfc_out(result, self); }
|
||||||
Tensor special_erfc(const Tensor& self) { return self.erfc(); }
|
Tensor special_erfc(const Tensor& self) { return self.erfc(); }
|
||||||
|
|
||||||
|
// special_erfinv, alias for erfinv
|
||||||
Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); }
|
Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); }
|
||||||
Tensor special_erfinv(const Tensor& self) { return self.erfinv(); }
|
Tensor special_erfinv(const Tensor& self) { return self.erfinv(); }
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8366,6 +8366,22 @@
|
||||||
dispatch:
|
dispatch:
|
||||||
CPU, CUDA: special_entr_out
|
CPU, CUDA: special_entr_out
|
||||||
|
|
||||||
|
- func: special_expm1(Tensor self) -> Tensor
|
||||||
|
python_module: special
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
- func: special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||||
|
python_module: special
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
- func: special_exp2(Tensor self) -> Tensor
|
||||||
|
python_module: special
|
||||||
|
variants: function
|
||||||
|
|
||||||
|
- func: special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||||
|
python_module: special
|
||||||
|
variants: function
|
||||||
|
|
||||||
- func: special_gammaln(Tensor self) -> Tensor
|
- func: special_gammaln(Tensor self) -> Tensor
|
||||||
python_module: special
|
python_module: special
|
||||||
variants: function
|
variants: function
|
||||||
|
|
|
||||||
|
|
@ -22,4 +22,6 @@ Functions
|
||||||
.. autofunction:: erf
|
.. autofunction:: erf
|
||||||
.. autofunction:: erfc
|
.. autofunction:: erfc
|
||||||
.. autofunction:: erfinv
|
.. autofunction:: erfinv
|
||||||
|
.. autofunction:: expm1
|
||||||
|
.. autofunction:: exp2
|
||||||
.. autofunction:: gammaln
|
.. autofunction:: gammaln
|
||||||
|
|
@ -3082,7 +3082,10 @@ Computes the base two exponential function of :attr:`input`.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
y_{i} = 2^{x_{i}}
|
y_{i} = 2^{x_{i}}
|
||||||
|
|
||||||
|
.. note:: Alias for :func:`torch.special.exp2`.
|
||||||
""" + r"""
|
""" + r"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
{input}
|
{input}
|
||||||
|
|
||||||
|
|
@ -3104,6 +3107,10 @@ of :attr:`input`.
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
y_{i} = e^{x_{i}} - 1
|
y_{i} = e^{x_{i}} - 1
|
||||||
|
|
||||||
|
.. note:: This function provides greater precision than exp(x) - 1 for small values of x.
|
||||||
|
|
||||||
|
.. note:: Alias for :func:`torch.special.expm1`.
|
||||||
""" + r"""
|
""" + r"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
||||||
|
|
@ -65,4 +65,28 @@ inline Tensor erfinv(const Tensor& self) {
|
||||||
return torch::special_erfinv(self);
|
return torch::special_erfinv(self);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes the base two exponential function of :attr:`input`, elementwise
|
||||||
|
/// See https://pytorch.org/docs/master/special.html#torch.special.exp2.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// auto t = torch::randn(128, dtype=kDouble);
|
||||||
|
/// torch::special::exp2(t);
|
||||||
|
/// ```
|
||||||
|
inline Tensor exp2(const Tensor& self) {
|
||||||
|
return torch::special_exp2(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the exponential of the elements minus 1, elementwise
|
||||||
|
/// See https://pytorch.org/docs/master/special.html#torch.special.expm1.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// auto t = torch::randn(128, dtype=kDouble);
|
||||||
|
/// torch::special::expm1(t);
|
||||||
|
/// ```
|
||||||
|
inline Tensor expm1(const Tensor& self) {
|
||||||
|
return torch::special_expm1(self);
|
||||||
|
}
|
||||||
|
|
||||||
}} // torch::special
|
}} // torch::special
|
||||||
|
|
|
||||||
|
|
@ -90,6 +90,8 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
|
||||||
{aten::special_erf, aten::erf},
|
{aten::special_erf, aten::erf},
|
||||||
{aten::special_erfc, aten::erfc},
|
{aten::special_erfc, aten::erfc},
|
||||||
{aten::special_erfinv, aten::erfinv},
|
{aten::special_erfinv, aten::erfinv},
|
||||||
|
{aten::special_exp2, aten::exp2},
|
||||||
|
{aten::special_expm1, aten::expm1},
|
||||||
{aten::orgqr, aten::linalg_householder_product},
|
{aten::orgqr, aten::linalg_householder_product},
|
||||||
{aten::special_gammaln, aten::lgamma}};
|
{aten::special_gammaln, aten::lgamma}};
|
||||||
return alias_map;
|
return alias_map;
|
||||||
|
|
|
||||||
|
|
@ -838,6 +838,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||||
torch.special.erf: lambda input: -1,
|
torch.special.erf: lambda input: -1,
|
||||||
torch.special.erfc: lambda input: -1,
|
torch.special.erfc: lambda input: -1,
|
||||||
torch.special.erfinv: lambda input: -1,
|
torch.special.erfinv: lambda input: -1,
|
||||||
|
torch.special.exp2: lambda input: -1,
|
||||||
|
torch.special.expm1: lambda input: -1,
|
||||||
torch.special.gammaln: lambda input: -1,
|
torch.special.gammaln: lambda input: -1,
|
||||||
torch.t: lambda input: -1,
|
torch.t: lambda input: -1,
|
||||||
torch.take: lambda input, index: -1,
|
torch.take: lambda input, index: -1,
|
||||||
|
|
|
||||||
|
|
@ -121,3 +121,52 @@ Example::
|
||||||
>>> torch.erfinv(torch.tensor([0, 0.5, -1.]))
|
>>> torch.erfinv(torch.tensor([0, 0.5, -1.]))
|
||||||
tensor([ 0.0000, 0.4769, -inf])
|
tensor([ 0.0000, 0.4769, -inf])
|
||||||
""".format(**common_args))
|
""".format(**common_args))
|
||||||
|
|
||||||
|
exp2 = _add_docstr(_special.special_exp2,
|
||||||
|
r"""
|
||||||
|
exp2(input, *, out=None) -> Tensor
|
||||||
|
|
||||||
|
Computes the base two exponential function of :attr:`input`.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
y_{i} = 2^{x_{i}}
|
||||||
|
|
||||||
|
""" + r"""
|
||||||
|
Args:
|
||||||
|
{input}
|
||||||
|
|
||||||
|
Keyword args:
|
||||||
|
{out}
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
|
||||||
|
tensor([ 1., 2., 8., 16.])
|
||||||
|
""".format(**common_args))
|
||||||
|
|
||||||
|
expm1 = _add_docstr(_special.special_expm1,
|
||||||
|
r"""
|
||||||
|
expm1(input, *, out=None) -> Tensor
|
||||||
|
|
||||||
|
Computes the exponential of the elements minus 1
|
||||||
|
of :attr:`input`.
|
||||||
|
|
||||||
|
..
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
y_{i} = e^{x_{i}} - 1
|
||||||
|
|
||||||
|
.. note:: This function provides greater precision than exp(x) - 1 for small values of x.
|
||||||
|
|
||||||
|
""" + r"""
|
||||||
|
Args:
|
||||||
|
{input}
|
||||||
|
|
||||||
|
Keyword args:
|
||||||
|
{out}
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> torch.expm1(torch.tensor([0, math.log(2.)]))
|
||||||
|
tensor([ 0., 1.])
|
||||||
|
""".format(**common_args))
|
||||||
|
|
|
||||||
|
|
@ -3334,12 +3334,14 @@ op_db: List[OpInfo] = [
|
||||||
dtypesIfCUDA=floating_types_and(torch.float16),
|
dtypesIfCUDA=floating_types_and(torch.float16),
|
||||||
assert_autodiffed=True),
|
assert_autodiffed=True),
|
||||||
UnaryUfuncInfo('exp2',
|
UnaryUfuncInfo('exp2',
|
||||||
|
aliases=('special.exp2', ),
|
||||||
ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2),
|
ref=np_unary_ufunc_integer_promotion_wrapper(np.exp2),
|
||||||
dtypes=all_types_and(torch.bool, torch.half),
|
dtypes=all_types_and(torch.bool, torch.half),
|
||||||
dtypesIfCPU=all_types_and(torch.bool, torch.half),
|
dtypesIfCPU=all_types_and(torch.bool, torch.half),
|
||||||
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
|
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
|
||||||
safe_casts_outputs=True),
|
safe_casts_outputs=True),
|
||||||
UnaryUfuncInfo('expm1',
|
UnaryUfuncInfo('expm1',
|
||||||
|
aliases=('special.expm1', ),
|
||||||
ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
|
ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
|
||||||
dtypes=all_types_and(torch.bool, torch.half),
|
dtypes=all_types_and(torch.bool, torch.half),
|
||||||
dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
|
dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user