Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Implement torch.special.log_ndtr
Implements torch.special.log_ndtr.

Issue: https://github.com/pytorch/pytorch/issues/50345

TODO:
- [x] add a proper reference to the SciPy implementation
- [x] double-check whether the changes in test/test_unary_ufuncs.py are really necessary
- [x] check the settings for UnaryUfuncInfo

cc: @kshitij12345 @mruberry

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74795
Approved by: https://github.com/anjali411
This commit is contained in: parent 6d36bbde7e, commit bbf7e159e0
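For context, the new operator is the logarithm of the standard normal CDF and mirrors scipy.special.log_ndtr; a minimal usage sketch (assuming a PyTorch build that contains this commit and an installed SciPy):

import torch
import scipy.special

x = torch.linspace(-10.0, 10.0, steps=21, dtype=torch.float64)
print(torch.special.log_ndtr(x))                            # new operator
print(torch.from_numpy(scipy.special.log_ndtr(x.numpy())))  # SciPy reference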
@@ -2160,4 +2160,21 @@ calc_erfcx(T x)
 }
 }
 
+/*
+ * Logarithm of Gaussian cumulative distribution function.
+ *
+ * This implementation of log_ndtr and its helper functions
+ * follow SciPy's implementation
+ * See NOTICE for the licenses.
+ */
+template <typename T>
+static inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
+  T t = x * M_SQRT1_2;
+  if (x < T{-1.0}) {
+    return std::log(calc_erfcx(-t) / 2) - t * t;
+  } else {
+    return std::log1p(-std::erfc(t) / 2);
+  }
+}
+
 C10_CLANG_DIAGNOSTIC_POP()
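For reference, the two branches follow from standard identities for the normal CDF Φ, with t = x/√2:

\[
\Phi(x) = \tfrac{1}{2}\operatorname{erfc}(-t), \qquad \operatorname{erfc}(z) = e^{-z^{2}}\operatorname{erfcx}(z).
\]

For x < -1, where erfc(-t) underflows, the scaled complementary error function keeps the result finite:

\[
\log\Phi(x) = \log\!\left(\frac{\operatorname{erfcx}(-t)}{2}\right) - t^{2}.
\]

For x >= -1, where Φ(x) is close to 1, writing Φ(x) = 1 - erfc(t)/2 and using log1p avoids cancellation:

\[
\log\Phi(x) = \operatorname{log1p}\!\left(-\frac{\operatorname{erfc}(t)}{2}\right).
\]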
@@ -67,6 +67,7 @@ CREATE_UNARY_FLOAT_META_FUNC(special_i0e)
 CREATE_UNARY_FLOAT_META_FUNC(special_i1)
 CREATE_UNARY_FLOAT_META_FUNC(special_i1e)
 CREATE_UNARY_FLOAT_META_FUNC(special_ndtri)
+CREATE_UNARY_FLOAT_META_FUNC(special_log_ndtr)
 CREATE_UNARY_FLOAT_META_FUNC(sqrt)
 CREATE_UNARY_FLOAT_META_FUNC(tan)
 CREATE_UNARY_FLOAT_META_FUNC(tanh)
@@ -184,6 +185,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(special_i0e_out, special_i0e_stub)
 CREATE_UNARY_TORCH_IMPL_FUNC(special_i1e_out, special_i1e_stub)
 CREATE_UNARY_TORCH_IMPL_FUNC(special_i1_out, special_i1_stub)
 CREATE_UNARY_TORCH_IMPL_FUNC(special_ndtri_out, special_ndtri_stub)
+CREATE_UNARY_TORCH_IMPL_FUNC(special_log_ndtr_out, special_log_ndtr_stub)
 CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub)
 CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub)
 CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub)
@@ -538,7 +540,7 @@ Tensor special_sinc(const Tensor& self) { return self.sinc(); }
 namespace {
 
 inline Tensor calc_ndtr(const Tensor& self) {
-  auto x_sqrt_2 = self / std::sqrt(2.);
+  auto x_sqrt_2 = self * M_SQRT1_2;
   return (1 + at::erf(x_sqrt_2)) * 0.5;
 }
 
@@ -841,6 +843,7 @@ DEFINE_DISPATCH(log1p_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(log2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(logical_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(special_ndtri_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(special_log_ndtr_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(neg_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(nan_to_num_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(polygamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
@@ -52,6 +52,7 @@ DECLARE_DISPATCH(unary_fn, log10_stub);
 DECLARE_DISPATCH(unary_fn, log1p_stub);
 DECLARE_DISPATCH(unary_fn, log2_stub);
 DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
+DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
 DECLARE_DISPATCH(unary_fn, neg_stub);
 
 DECLARE_DISPATCH(unary_fn, reciprocal_stub);
@@ -504,6 +504,13 @@ static void ndtri_kernel(TensorIteratorBase& iter) {
   });
 }
 
+static void log_ndtr_kernel(TensorIteratorBase& iter) {
+  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
+  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cpu", [&]() {
+    cpu_kernel(iter, [](scalar_t x) { return calc_log_ndtr(x); });
+  });
+}
+
 static void i0e_kernel(TensorIteratorBase& iter) {
   TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
   AT_DISPATCH_FLOATING_TYPES_AND(
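A standalone sketch of the per-element computation the kernel performs, written against SciPy to show why the erfcx branch matters for large negative inputs (illustrative only; log_ndtr_ref is not part of the patch):

import numpy as np
from scipy import special

def log_ndtr_ref(x):
    # Same branching as the kernel: erfcx-based branch for x < -1, log1p otherwise.
    t = x / np.sqrt(2.0)
    if x < -1.0:
        return np.log(special.erfcx(-t) / 2.0) - t * t
    return np.log1p(-special.erfc(t) / 2.0)

for x in (-40.0, -5.0, 0.0, 3.0):
    print(x, log_ndtr_ref(x), special.log_ndtr(x))

# A naive np.log(special.ndtr(-40.0)) returns -inf, because ndtr(-40) underflows to 0.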
@@ -641,6 +648,7 @@ REGISTER_DISPATCH(special_entr_stub, &CPU_CAPABILITY::entr_kernel);
 REGISTER_DISPATCH(frexp_stub, &CPU_CAPABILITY::frexp_kernel);
 REGISTER_DISPATCH(special_i0e_stub, &CPU_CAPABILITY::i0e_kernel);
 REGISTER_DISPATCH(special_ndtri_stub, &CPU_CAPABILITY::ndtri_kernel);
+REGISTER_DISPATCH(special_log_ndtr_stub, &CPU_CAPABILITY::log_ndtr_kernel);
 REGISTER_DISPATCH(special_i1_stub, &CPU_CAPABILITY::i1_kernel);
 REGISTER_DISPATCH(special_i1e_stub, &CPU_CAPABILITY::i1e_kernel);
 REGISTER_DISPATCH(special_erfcx_stub, &CPU_CAPABILITY::erfcx_kernel);
@@ -276,6 +276,19 @@ const auto ndtri_string = jiterator_stringify(
   }
 ); // ndtri_string
 
+const auto log_ndtr_string = jiterator_stringify(
+  template <typename T>
+  T log_ndtr(T x) {
+    constexpr T SQRT1_2{0.707106781186547524400844362104849039}; // 1/sqrt(2)
+    T t = x * SQRT1_2;
+    if (x < T{-1.0}) {
+      return log(erfcx(-t) / 2) - t * t;
+    } else {
+      return log1p(-erfc(t) / 2);
+    }
+  }
+); // log_ndtr_string
+
 const auto gcd_string = jiterator_stringify(
   template <typename T>
   T gcd(const T a_in, const T b_in) {
@@ -229,6 +229,23 @@ void ndtri_kernel_cuda(TensorIteratorBase& iter) {
   #endif
 }
 
+const char log_ndtr_name[] = "log_ndtr";
+void log_ndtr_kernel_cuda(TensorIteratorBase& iter) {
+  #if AT_USE_JITERATOR()
+    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cuda", [&]() {
+      jitted_gpu_kernel</*name=*/log_ndtr_name,
+                        /*return_dtype=*/ scalar_t,
+                        /*common_dtype=*/ scalar_t,
+                        /*arity=*/ 1>(iter, log_ndtr_string);
+    });
+  #else
+    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cuda", [&]() {
+      gpu_kernel(
+          iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_log_ndtr(a); });
+    });
+  #endif
+}
+
 void erf_kernel_cuda(TensorIteratorBase& iter) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() {
     gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
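A simple parity check between the CPU kernel and this CUDA path (whether the jiterator route or the gpu_kernel fallback is taken), as a sketch that assumes a CUDA-enabled build with an available device:

import torch

if torch.cuda.is_available():
    x = torch.linspace(-30.0, 10.0, steps=401, dtype=torch.float64)
    cpu_out = torch.special.log_ndtr(x)
    gpu_out = torch.special.log_ndtr(x.cuda()).cpu()
    print(torch.allclose(cpu_out, gpu_out))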
@@ -349,6 +366,7 @@ REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
 REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
 REGISTER_DISPATCH(special_entr_stub, &entr_kernel_cuda);
 REGISTER_DISPATCH(special_ndtri_stub, &ndtri_kernel_cuda);
+REGISTER_DISPATCH(special_log_ndtr_stub, &log_ndtr_kernel_cuda);
 REGISTER_DISPATCH(special_erfcx_stub, &erfcx_kernel_cuda);
 
 } // namespace native
@@ -10278,6 +10278,19 @@
   dispatch:
     CPU, CUDA: special_ndtri_out
 
+- func: special_log_ndtr(Tensor self) -> Tensor
+  structured_delegate: special_log_ndtr.out
+  python_module: special
+  variants: function
+
+- func: special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: special
+  variants: function
+  dispatch:
+    CPU, CUDA: special_log_ndtr_out
+
 - func: special_expm1(Tensor self) -> Tensor
   python_module: special
   variants: function
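Both schema entries are reachable from Python: the first as the functional form, the second as the structured out= overload. A small sketch, assuming a build that contains this commit:

import torch

x = torch.randn(4, dtype=torch.float64)
out = torch.empty_like(x)
torch.special.log_ndtr(x, out=out)                  # special_log_ndtr.out overload
print(torch.equal(out, torch.special.log_ndtr(x)))  # functional variant, same values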
@@ -37,6 +37,7 @@ Functions
 .. autofunction:: multigammaln
 .. autofunction:: ndtr
 .. autofunction:: ndtri
+.. autofunction:: log_ndtr
 .. autofunction:: round
 .. autofunction:: sinc
 .. autofunction:: softmax
@@ -1267,11 +1267,25 @@ class TestUnaryUfuncs(TestCase):
             self.assertEqual(actual, expected)
 
         range = (-10, 10)
 
-        t = torch.linspace(*range, int(1e4), device=device, dtype=dtype)
+        t = torch.linspace(*range, 1, device=device, dtype=dtype)
         check_equal(t)
 
-        # NaN, inf, -inf are tested in reference_numerics tests.
+        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
         info = torch.finfo(dtype)
         min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
         t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
         check_equal(t)
 
+    @dtypes(torch.float32, torch.float64)
+    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
+    def test_special_log_ndtr_vs_scipy(self, device, dtype):
+        def check_equal(t):
+            # Test by comparing with scipy
+            actual = torch.special.log_ndtr(t)
+            expected = scipy.special.log_ndtr(t.cpu().numpy())
+            self.assertEqual(actual, expected)
+
+        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
+        info = torch.finfo(dtype)
+        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
+        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
@@ -1268,6 +1268,10 @@
   self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp()
   result: auto_element_wise
 
+- name: special_log_ndtr(Tensor self) -> Tensor
+  self: grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp()
+  result: auto_element_wise
+
 # [Note: Sometimes view derivatives]
 # The following situation applies to other operations as well.
 # TODO: This note is only referenced once by to_dense. Make this
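For reference, the backward entry is the derivative of log Φ expressed in terms of the forward result. With result = log Φ(x),

\[
\frac{d}{dx}\log\Phi(x) = \frac{\varphi(x)}{\Phi(x)}
= \frac{1}{\sqrt{2\pi}}\, e^{-x^{2}/2}\, e^{-\log\Phi(x)}
= \frac{1}{\sqrt{2\pi}}\exp\!\left(-\left(\mathrm{result} + \frac{x^{2}}{2}\right)\right),
\]

which is exactly grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp(). A quick numerical cross-check is torch.autograd.gradcheck(torch.special.log_ndtr, (x,)) on a small double-precision tensor x with requires_grad=True.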
@@ -215,6 +215,15 @@ inline Tensor& logsumexp_out(Tensor& result, const Tensor& self, IntArrayRef dim
   return torch::special_logsumexp_out(result, self, dims, keepdim);
 }
 
+/// Computes the argument, x, for which the area under the Gaussian probability density
+/// function (integrated from minus infinity to x) is equal to input, elementwise.
+/// See https://pytorch.org/docs/master/special.html#torch.special.ndtri
+///
+/// Example:
+/// ```
+/// auto t = torch::rand(128, dtype=kDouble);
+/// torch::special::ndtri(t);
+/// ```
 inline Tensor ndtri(const Tensor& self) {
   return torch::special_ndtri(self);
 }
@@ -223,6 +232,23 @@ inline Tensor& ndtri_out(Tensor& result, const Tensor& self) {
   return torch::special_ndtri_out(result, self);
 }
 
+/// Computes the log of area under the standard Gaussian probability density function,
+/// integrated from minus infinity to :attr:`input`, elementwise
+/// See https://pytorch.org/docs/master/special.html#torch.special.log_ndtr
+///
+/// Example:
+/// ```
+/// auto t = torch::randn(128, dtype=kDouble);
+/// torch::special::log_ndtr(t);
+/// ```
+inline Tensor log_ndtr(const Tensor& self) {
+  return torch::special_log_ndtr(self);
+}
+
+inline Tensor& log_ndtr_out(Tensor& result, const Tensor& self) {
+  return torch::special_log_ndtr_out(result, self);
+}
+
 /// Computes the logit of input, elementwise.
 /// See https://pytorch.org/docs/master/special.html#torch.special.logit.
 ///
@@ -970,6 +970,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.special.multigammaln: lambda input, p: -1,
         torch.special.ndtri: lambda input: -1,
         torch.special.ndtr: lambda input: -1,
+        torch.special.log_ndtr: lambda input: -1,
         torch.special.xlogy: lambda input, other, out=None: -1,
         torch.special.xlog1py: lambda input, other, out=None: -1,
         torch.special.zeta: lambda self, other, out=None: -1,
@@ -4,7 +4,7 @@ from torch._torch_docs import common_args, multi_dim_common
 
 __all__ = ['entr', 'psi', 'digamma', 'gammaln', 'polygamma', 'erf', 'erfc', 'erfinv',
            'erfcx', 'logit', 'logsumexp', 'expit', 'exp2', 'expm1', 'xlog1py', 'xlogy',
-           'i0', 'i0e', 'i1', 'i1e', 'ndtr', 'ndtri', 'log1p', 'sinc', 'round', 'log_softmax',
+           'i0', 'i0e', 'i1', 'i1e', 'ndtr', 'ndtri', 'log_ndtr', 'log1p', 'sinc', 'round', 'log_softmax',
            'zeta', 'multigammaln', 'gammainc', 'gammaincc', 'softmax']
 
 Tensor = torch.Tensor
@@ -547,6 +547,27 @@ Example::
     tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
 """.format(**common_args))
 
+log_ndtr = _add_docstr(_special.special_log_ndtr,
+                       r"""
+log_ndtr(input, *, out=None) -> Tensor
+Computes the log of the area under the standard Gaussian probability density function,
+integrated from minus infinity to :attr:`input`, elementwise.
+
+.. math::
+    \text{log\_ndtr}(x) = \log\left(\frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \right)
+
+""" + r"""
+Args:
+    {input}
+
+Keyword args:
+    {out}
+
+Example::
+    >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
+    tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014])
+""".format(**common_args))
+
 log1p = _add_docstr(_special.special_log1p,
                     r"""
 log1p(input, *, out=None) -> Tensor
@@ -14684,6 +14684,14 @@ op_db: List[OpInfo] = [
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    safe_casts_outputs=True),
+    UnaryUfuncInfo('special.log_ndtr',
+                   aten_name='special_log_ndtr',
+                   ref=scipy.special.log_ndtr if TEST_SCIPY else _NOTHING,
+                   dtypes=all_types_and(torch.bool),
+                   supports_forward_ad=True,
+                   supports_fwgrad_bwgrad=True,
+                   safe_casts_outputs=True,
+                   ),
     UnaryUfuncInfo('erf',
                    ref=scipy.special.erf if TEST_SCIPY else _NOTHING,
                    aliases=('special.erf', ),