Implement torch.special.log_ndtr

Implements torch.special.log_ndtr

Issue: https://github.com/pytorch/pytorch/issues/50345

TODO:
- [x] add a proper reference to the SciPy implementation
- [x] double-check whether the changes in test/test_unary_ufuncs.py are really necessary
- [x] check the settings for UnaryUfuncInfo
cc: @kshitij12345 @mruberry
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74795
Approved by: https://github.com/anjali411
This commit is contained in:
Sherlockk Huang 2022-03-29 23:13:37 +00:00 committed by PyTorch MergeBot
parent 6d36bbde7e
commit bbf7e159e0
14 changed files with 153 additions and 5 deletions

View File

@@ -2160,4 +2160,21 @@ calc_erfcx(T x)
}
}
/*
 * Logarithm of the Gaussian cumulative distribution function.
 * This implementation of log_ndtr and its helper functions
 * follow SciPy's implementation.
 * See NOTICE for the licenses.
 */
template <typename T>
static inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
  T t = x * M_SQRT1_2;
  if (x < T{-1.0}) {
    return std::log(calc_erfcx(-t) / 2) - t * t;
  } else {
    return std::log1p(-std::erfc(t) / 2);
  }
}
C10_CLANG_DIAGNOSTIC_POP()
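The two branches above are the usual trick for a numerically stable log-CDF: for x < -1 the CDF itself can underflow, so the scaled complementary error function is used (erfc(t) = erfcx(t)·exp(-t²)); for x >= -1 it evaluates log1p(-erfc(t)/2), which keeps relative accuracy when the result is close to zero. A minimal NumPy/SciPy mirror of the same logic, useful for spot-checking against scipy.special.log_ndtr (my own illustration, not part of the patch):

```python
import numpy as np
from scipy import special

def log_ndtr_ref(x):
    # Mirrors calc_log_ndtr above for a Python float x.
    t = x * np.sqrt(0.5)
    if x < -1.0:
        # ndtr(x) = erfcx(-t) * exp(-t*t) / 2, so the log stays finite even when ndtr underflows.
        return np.log(special.erfcx(-t) / 2) - t * t
    # For x >= -1, erfc(t)/2 is at most ~0.84, so log1p(-erfc(t)/2) is well conditioned.
    return np.log1p(-special.erfc(t) / 2)

for v in (-40.0, -5.0, -1.5, 0.0, 3.0):
    assert np.isclose(log_ndtr_ref(v), special.log_ndtr(v))
```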

View File

@@ -67,6 +67,7 @@ CREATE_UNARY_FLOAT_META_FUNC(special_i0e)
CREATE_UNARY_FLOAT_META_FUNC(special_i1)
CREATE_UNARY_FLOAT_META_FUNC(special_i1e)
CREATE_UNARY_FLOAT_META_FUNC(special_ndtri)
CREATE_UNARY_FLOAT_META_FUNC(special_log_ndtr)
CREATE_UNARY_FLOAT_META_FUNC(sqrt)
CREATE_UNARY_FLOAT_META_FUNC(tan)
CREATE_UNARY_FLOAT_META_FUNC(tanh)
@@ -184,6 +185,7 @@ CREATE_UNARY_TORCH_IMPL_FUNC(special_i0e_out, special_i0e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1e_out, special_i1e_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_i1_out, special_i1_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_ndtri_out, special_ndtri_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(special_log_ndtr_out, special_log_ndtr_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(sqrt_out, sqrt_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tan_out, tan_stub)
CREATE_UNARY_TORCH_IMPL_FUNC(tanh_out, tanh_stub)
@@ -538,7 +540,7 @@ Tensor special_sinc(const Tensor& self) { return self.sinc(); }
namespace {
inline Tensor calc_ndtr(const Tensor& self) {
  auto x_sqrt_2 = self / std::sqrt(2.);
  auto x_sqrt_2 = self * M_SQRT1_2;
  return (1 + at::erf(x_sqrt_2)) * 0.5;
}
@@ -841,6 +843,7 @@ DEFINE_DISPATCH(log1p_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-
DEFINE_DISPATCH(log2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(logical_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_ndtri_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(special_log_ndtr_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(neg_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(nan_to_num_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(polygamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
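One note on CREATE_UNARY_FLOAT_META_FUNC(special_log_ndtr) above: like the other *_FLOAT unary ops, the meta function promotes integral and bool inputs to the default floating dtype, while floating inputs keep theirs. A minimal sketch of the resulting Python-level behavior, assuming a build that includes this commit:

```python
import torch

# Floating inputs keep their dtype.
print(torch.special.log_ndtr(torch.tensor([0.5], dtype=torch.float64)).dtype)  # torch.float64
# Integral and bool inputs are computed in the default float dtype.
print(torch.special.log_ndtr(torch.tensor([0, 1])).dtype)                      # torch.float32
```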

View File

@@ -52,6 +52,7 @@ DECLARE_DISPATCH(unary_fn, log10_stub);
DECLARE_DISPATCH(unary_fn, log1p_stub);
DECLARE_DISPATCH(unary_fn, log2_stub);
DECLARE_DISPATCH(unary_fn, special_ndtri_stub);
DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub);
DECLARE_DISPATCH(unary_fn, neg_stub);
DECLARE_DISPATCH(unary_fn, reciprocal_stub);

View File

@@ -504,6 +504,13 @@ static void ndtri_kernel(TensorIteratorBase& iter) {
  });
}

static void log_ndtr_kernel(TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t x) { return calc_log_ndtr(x); });
  });
}
static void i0e_kernel(TensorIteratorBase& iter) {
TORCH_INTERNAL_ASSERT(iter.ntensors() == 2);
AT_DISPATCH_FLOATING_TYPES_AND(
@@ -641,6 +648,7 @@ REGISTER_DISPATCH(special_entr_stub, &CPU_CAPABILITY::entr_kernel);
REGISTER_DISPATCH(frexp_stub, &CPU_CAPABILITY::frexp_kernel);
REGISTER_DISPATCH(special_i0e_stub, &CPU_CAPABILITY::i0e_kernel);
REGISTER_DISPATCH(special_ndtri_stub, &CPU_CAPABILITY::ndtri_kernel);
REGISTER_DISPATCH(special_log_ndtr_stub, &CPU_CAPABILITY::log_ndtr_kernel);
REGISTER_DISPATCH(special_i1_stub, &CPU_CAPABILITY::i1_kernel);
REGISTER_DISPATCH(special_i1e_stub, &CPU_CAPABILITY::i1e_kernel);
REGISTER_DISPATCH(special_erfcx_stub, &CPU_CAPABILITY::erfcx_kernel);

View File

@@ -276,6 +276,19 @@ const auto ndtri_string = jiterator_stringify(
  }
); // ndtri_string

const auto log_ndtr_string = jiterator_stringify(
  template <typename T>
  T log_ndtr(T x) {
    constexpr T SQRT1_2{0.707106781186547524400844362104849039}; // 1/sqrt(2)
    T t = x * SQRT1_2;
    if (x < T{-1.0}) {
      return log(erfcx(-t) / 2) - t * t;
    } else {
      return log1p(-erfc(t) / 2);
    }
  }
); // log_ndtr_string

const auto gcd_string = jiterator_stringify(
  template <typename T>
  T gcd(const T a_in, const T b_in) {

View File

@@ -229,6 +229,23 @@ void ndtri_kernel_cuda(TensorIteratorBase& iter) {
#endif
}

const char log_ndtr_name[] = "log_ndtr";
void log_ndtr_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cuda", [&]() {
    jitted_gpu_kernel</*name=*/log_ndtr_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, log_ndtr_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cuda", [&]() {
    gpu_kernel(
        iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_log_ndtr(a); });
  });
#endif
}

void erf_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
@@ -349,6 +366,7 @@ REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
REGISTER_DISPATCH(special_entr_stub, &entr_kernel_cuda);
REGISTER_DISPATCH(special_ndtri_stub, &ndtri_kernel_cuda);
REGISTER_DISPATCH(special_log_ndtr_stub, &log_ndtr_kernel_cuda);
REGISTER_DISPATCH(special_erfcx_stub, &erfcx_kernel_cuda);
} // namespace native

View File

@@ -10278,6 +10278,19 @@
  dispatch:
    CPU, CUDA: special_ndtri_out

- func: special_log_ndtr(Tensor self) -> Tensor
  structured_delegate: special_log_ndtr.out
  python_module: special
  variants: function

- func: special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  python_module: special
  variants: function
  dispatch:
    CPU, CUDA: special_log_ndtr_out

- func: special_expm1(Tensor self) -> Tensor
  python_module: special
  variants: function
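The functional special_log_ndtr entry above delegates to the structured special_log_ndtr.out kernel, so the plain call and the out= call share one implementation. A hedged sketch of what these two entries surface in Python, assuming a build that includes this commit:

```python
import torch

x = torch.linspace(-3, 3, 7)
y = torch.special.log_ndtr(x)        # functional variant (special_log_ndtr)
out = torch.empty_like(x)
torch.special.log_ndtr(x, out=out)   # out= variant (special_log_ndtr.out)
assert torch.equal(y, out)           # the same structured kernel backs both calls
```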

View File

@@ -37,6 +37,7 @@ Functions
.. autofunction:: multigammaln
.. autofunction:: ndtr
.. autofunction:: ndtri
.. autofunction:: log_ndtr
.. autofunction:: round
.. autofunction:: sinc
.. autofunction:: softmax

View File

@@ -1267,11 +1267,25 @@ class TestUnaryUfuncs(TestCase):
            self.assertEqual(actual, expected)

        range = (-10, 10)
        t = torch.linspace(*range, int(1e4), device=device, dtype=dtype)
        t = torch.linspace(*range, 1, device=device, dtype=dtype)
        check_equal(t)
        # NaN, inf, -inf are tested in reference_numerics tests.
        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
        info = torch.finfo(dtype)
        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)
        check_equal(t)

    @dtypes(torch.float32, torch.float64)
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    def test_special_log_ndtr_vs_scipy(self, device, dtype):
        def check_equal(t):
            # Test by comparing with scipy
            actual = torch.special.log_ndtr(t)
            expected = scipy.special.log_ndtr(t.cpu().numpy())
            self.assertEqual(actual, expected)

        # Skip testing NaN, inf, -inf since they are tested in reference_numerics tests.
        info = torch.finfo(dtype)
        min, max, eps, tiny = info.min, info.max, info.eps, info.tiny
        t = torch.tensor([min, max, eps, tiny], dtype=dtype, device=device)

View File

@@ -1268,6 +1268,10 @@
  self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp()
  result: auto_element_wise

- name: special_log_ndtr(Tensor self) -> Tensor
  self: grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp()
  result: auto_element_wise
# [Note: Sometimes view derivatives]
# The following situation applies to other operations as well.
# TODO: This note is only referenced once by to_dense. Make this
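For reference, the backward formula in the special_log_ndtr entry above is the standard derivative of the log-CDF; a short derivation (my own working, not part of the patch), with Φ the standard normal CDF and φ its density:

```latex
\frac{d}{dx}\log\Phi(x) = \frac{\phi(x)}{\Phi(x)}
  = \frac{e^{-x^{2}/2}}{\sqrt{2\pi}\, e^{\log\Phi(x)}}
  = \frac{1}{\sqrt{2\pi}} \exp\!\left(-\Big(\log\Phi(x) + \frac{x^{2}}{2}\Big)\right)
```

With `result = log_ndtr(self)`, this is exactly `grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp()`.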

View File

@@ -215,6 +215,15 @@ inline Tensor& logsumexp_out(Tensor& result, const Tensor& self, IntArrayRef dim
  return torch::special_logsumexp_out(result, self, dims, keepdim);
}

/// Computes the argument, x, for which the area under the Gaussian probability density
/// function (integrated from minus infinity to x) is equal to input, elementwise.
/// See https://pytorch.org/docs/master/special.html#torch.special.ndtri
///
/// Example:
/// ```
/// auto t = torch::rand(128, dtype=kDouble);
/// torch::special::ndtri(t);
/// ```
inline Tensor ndtri(const Tensor& self) {
  return torch::special_ndtri(self);
}

@@ -223,6 +232,23 @@ inline Tensor& ndtri_out(Tensor& result, const Tensor& self) {
  return torch::special_ndtri_out(result, self);
}
/// Computes the log of the area under the standard Gaussian probability density function,
/// integrated from minus infinity to :attr:`input`, elementwise.
/// See https://pytorch.org/docs/master/special.html#torch.special.log_ndtr
///
/// Example:
/// ```
/// auto t = torch::randn(128, dtype=kDouble);
/// torch::special::log_ndtr(t);
/// ```
inline Tensor log_ndtr(const Tensor& self) {
  return torch::special_log_ndtr(self);
}

inline Tensor& log_ndtr_out(Tensor& result, const Tensor& self) {
  return torch::special_log_ndtr_out(result, self);
}
/// Computes the logit of input, elementwise.
/// See https://pytorch.org/docs/master/special.html#torch.special.logit.
///

View File

@@ -970,6 +970,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.special.multigammaln: lambda input, p: -1,
        torch.special.ndtri: lambda input: -1,
        torch.special.ndtr: lambda input: -1,
        torch.special.log_ndtr: lambda input: -1,
        torch.special.xlogy: lambda input, other, out=None: -1,
        torch.special.xlog1py: lambda input, other, out=None: -1,
        torch.special.zeta: lambda self, other, out=None: -1,
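The get_testing_overrides() entry above makes torch.special.log_ndtr participate in the __torch_function__ override protocol like its neighbors. A rough sketch of what that enables (illustrative only; the exact callable passed to __torch_function__ may be the underlying special_log_ndtr binding rather than the public alias):

```python
import torch

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Report every intercepted torch call, then fall back to the default behavior.
        print("intercepted:", getattr(func, "__name__", func))
        return super().__torch_function__(func, types, args, kwargs or {})

t = torch.zeros(3).as_subclass(LoggingTensor)
torch.special.log_ndtr(t)  # prints the intercepted call, then computes log(0.5) per element
```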

View File

@@ -4,7 +4,7 @@ from torch._torch_docs import common_args, multi_dim_common
__all__ = ['entr', 'psi', 'digamma', 'gammaln', 'polygamma', 'erf', 'erfc', 'erfinv',
           'erfcx', 'logit', 'logsumexp', 'expit', 'exp2', 'expm1', 'xlog1py', 'xlogy',
           'i0', 'i0e', 'i1', 'i1e', 'ndtr', 'ndtri', 'log1p', 'sinc', 'round', 'log_softmax',
           'i0', 'i0e', 'i1', 'i1e', 'ndtr', 'ndtri', 'log_ndtr', 'log1p', 'sinc', 'round', 'log_softmax',
           'zeta', 'multigammaln', 'gammainc', 'gammaincc', 'softmax']
Tensor = torch.Tensor
@@ -547,6 +547,27 @@ Example::
    tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
""".format(**common_args))

log_ndtr = _add_docstr(_special.special_log_ndtr,
                       r"""
log_ndtr(input, *, out=None) -> Tensor

Computes the log of the area under the standard Gaussian probability density function,
integrated from minus infinity to :attr:`input`, elementwise.

.. math::
    \text{log\_ndtr}(x) = \log\left(\frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \right)

""" + r"""
Args:
    {input}

Keyword args:
    {out}

Example::

    >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
    tensor([-6.6077, -3.7832, -1.8410, -0.6931, -0.1728, -0.0230, -0.0014])
""".format(**common_args))
log1p = _add_docstr(_special.special_log1p,
r"""
log1p(input, *, out=None) -> Tensor
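As a usage note for the log_ndtr docstring above: the reason to expose log_ndtr directly, rather than composing ndtr with log, is numerical range for very negative inputs. A small sketch, not part of the patch:

```python
import torch

x = torch.tensor([-40.0, -10.0, 0.0], dtype=torch.float64)
naive = torch.special.ndtr(x).log()   # Phi(x) underflows/cancels to 0 for very negative x, so log() hits -inf
stable = torch.special.log_ndtr(x)    # stays finite: roughly -804.6, -53.2, -0.693
print(naive)
print(stable)
```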

View File

@@ -14684,6 +14684,14 @@ op_db: List[OpInfo] = [
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   safe_casts_outputs=True),
    UnaryUfuncInfo('special.log_ndtr',
                   aten_name='special_log_ndtr',
                   ref=scipy.special.log_ndtr if TEST_SCIPY else _NOTHING,
                   dtypes=all_types_and(torch.bool),
                   supports_forward_ad=True,
                   supports_fwgrad_bwgrad=True,
                   safe_casts_outputs=True,
                   ),
    UnaryUfuncInfo('erf',
                   ref=scipy.special.erf if TEST_SCIPY else _NOTHING,
                   aliases=('special.erf', ),