Adding multigammaln ref and fix arange (#85153)

Partially based on https://github.com/pytorch/pytorch/pull/83662.

I'll help land this one, as Rob no longer works on the PyTorch project.

I removed the data-dependent check on the args, as data dependencies
are problematic for many reasons (and this one failed when the input contained NaNs).

It also registers arange as a decomposition and fixes the naming of its
args.
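
For context, the removed check was data-dependent: it had to read tensor values on the host via .item(), which forces a device sync, cannot run under meta/fake tensors, and raised spuriously for NaN inputs. A minimal Python sketch of the two styles of check (hypothetical helper names, mirroring the C++ change below):

import torch

def mvlgamma_check_data_dependent(self: torch.Tensor, p: int) -> None:
    # Reads actual values: needs a host sync, cannot run on meta/fake
    # tensors, and any NaN makes the comparison False, so it raises.
    if not bool((self > 0.5 * (p - 1)).all().item()):
        raise RuntimeError("All elements must be greater than (p-1)/2")

def mvlgamma_check_metadata_only(self: torch.Tensor, p: int) -> None:
    # Only inspects metadata (dtype and p), so it works under meta/fake tensors.
    if self.dtype == torch.bool:
        raise RuntimeError("The input tensor may not be a boolean tensor.")
    if p < 1:
        raise RuntimeError("p has to be greater than or equal to 1")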
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85153
Approved by: https://github.com/mruberry, https://github.com/ngimel
lezcano 2022-09-20 12:40:28 +00:00 committed by PyTorch MergeBot
parent 7a6c4d0c50
commit d17b144e65
9 changed files with 106 additions and 133 deletions

View File

@ -723,8 +723,7 @@ constexpr double QUARTER = 0.25;
}
static inline void mvlgamma_check(const Tensor& self, int64_t p) {
TORCH_CHECK((self > HALF * (p - 1)).all().item<bool>(),
"All elements must be greater than (p-1)/2");
TORCH_CHECK(self.scalar_type() != kBool, "The input tensor may not be a boolean tensor.");
TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1");
}

View File

@ -423,7 +423,6 @@ meta_function_expected_failures = {
torch.median : {f64, i32, i64, u8, i16, bf16, i8, f32},
torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
torch.multinomial : {f64, bf16, f32},
torch.mvlgamma : {f64, i32, i64, u8, i16, bf16, i8, f32},
torch.nn.functional.ctc_loss : {f64, f32},
torch.nn.functional.gaussian_nll_loss : {f64, bf16, f32},
torch.nn.functional.max_pool3d : {f64, f32},
@ -543,7 +542,6 @@ meta_function_device_expected_failures['cuda'] = {
torch.matrix_exp: {f16}, # aten::linalg_matrix_exp
torch.median: {f16}, # aten::median, aten::median.dim_values
torch.multinomial: {f16}, # aten::multinomial, aten::multinomial.out
torch.mvlgamma: {f16}, # aten::_local_scalar_dense, aten::mvlgamma.out
torch.nn.functional.gaussian_nll_loss: {f16}, # aten::_local_scalar_dense
torch.nn.functional.max_pool3d: {bf16, f16}, # aten::max_pool3d_with_indices
torch.nn.functional.max_pool3d_with_indices: {bf16, f16}, # aten::max_pool3d_with_indices
@ -687,8 +685,6 @@ meta_dispatch_expected_failures = {
aten.multilabel_margin_loss_forward.default : {f32, f64},
aten.multinomial.default : {bf16, f32, f64},
aten.multinomial.out : {bf16, f32, f64},
aten.mvlgamma.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
aten.mvlgamma.out : {i8, f64, i64, bf16, f32, i32, i16, u8},
aten.nll_loss2d_forward.default : {bf16, f32, f64},
aten.polar.default : {f32, f64},
aten.rrelu_with_noise.default : {bf16, f32, f64},
@ -745,8 +741,6 @@ meta_dispatch_device_expected_failures['cuda'] = {
aten.multilabel_margin_loss_forward.default: {bf16, f16}, # aten::multilabel_margin_loss_forward
aten.multinomial.default: {f16}, # aten::multinomial
aten.multinomial.out: {f16}, # aten::multinomial.out
aten.mvlgamma.default: {f16}, # aten::_local_scalar_dense
aten.mvlgamma.out: {f16}, # aten::mvlgamma.out
aten.native_group_norm.default: {bf16, f16},
aten.nll_loss2d_forward.default: {f16}, # aten::nll_loss2d_forward
aten.ormqr.default: {f32, f64}, # aten::ormqr

View File

@ -997,9 +997,6 @@ fake_tensor_failures = {
# FakeTensor fallback doesn't work
xfail('segment_reduce', 'lengths'),
xfail('multinomial'),
xfail('mvlgamma', 'mvlgamma_p_1'),
xfail('mvlgamma', 'mvlgamma_p_3'),
xfail('mvlgamma', 'mvlgamma_p_5'),
xfail('cholesky'),
xfail('cholesky_inverse'),
# ASAN failures due to divide by 0

View File

@ -624,16 +624,6 @@ class TestUnaryUfuncs(TestCase):
):
torch.frexp(input, out=(mantissa, exponent))
def test_mvlgamma_argcheck(self, device):
def run_test(d):
input = torch.linspace((d - 2) / 2, 10, 10, device=device)
torch.mvlgamma(input, d)
with self.assertRaisesRegex(
RuntimeError, r"All elements must be greater than \(p-1\)/2"
):
run_test(3)
def test_polygamma_neg(self, device):
with self.assertRaisesRegex(
RuntimeError, r"polygamma\(n, x\) does not support negative n\."

View File

@ -1579,7 +1579,7 @@ def index_add_(
utils.is_weakly_lesser_type(type(alpha), python_type),
lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
)
tensor = torch._prims.mul(tensor, alpha)
tensor = tensor * alpha
idx = (slice(None),) * dim + (index,)
torch.ops.aten.index_put_(x, idx, tensor, accumulate=True)
return x

View File

@ -610,6 +610,10 @@ def lgamma(a):
return prims.lgamma(a)
# alias
mvlgamma = torch.special.multigammaln # type: ignore[has-type]
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log(a):
return prims.log(a)
@ -3707,37 +3711,6 @@ def empty_like(
)
@overload
def arange(
end: NumberType,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
layout: torch.layout = torch.strided,
pin_memory: bool = False,
requires_grad: bool = False,
) -> TensorLikeType:
pass
@overload
def arange(
start: NumberType,
end: NumberType,
step: NumberType = 1,
*,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
layout: torch.layout = torch.strided,
pin_memory: bool = False,
requires_grad: bool = False,
) -> TensorLikeType:
pass
# See https://github.com/pytorch/pytorch/issues/82364
# @register_decomposition(torch.ops.aten.arange)
# @out_wrapper()
@register_decomposition(
[
torch.ops.aten.arange.default,
@ -3745,9 +3718,10 @@ def arange(
torch.ops.aten.arange.start_step,
]
)
@out_wrapper()
def arange(
a: Optional[NumberType] = None,
b: Optional[NumberType] = None,
start: NumberType = 0,
end: Optional[NumberType] = None,
step: NumberType = 1,
*,
dtype: Optional[torch.dtype] = None,
@ -3756,11 +3730,15 @@ def arange(
pin_memory: bool = False,
requires_grad: bool = False,
) -> TensorLikeType:
assert (a is not None and b is not None) or (a is not None and b is None)
if a is not None and b is not None:
assert not pin_memory
assert layout == torch.strided
# Case: torch.arange(5)
if end is None:
end = start
start = 0
return prims.arange(
a,
b,
start,
end,
step,
dtype=dtype,
device=device,
@ -3768,19 +3746,6 @@ def arange(
# pin_memory=pin_memory,
requires_grad=requires_grad,
)
elif a is not None and b is None:
return prims.arange(
0,
a,
step,
dtype=dtype,
device=device,
# layout=layout,
# pin_memory=pin_memory,
requires_grad=requires_grad,
)
else:
raise AssertionError()
@register_decomposition(torch.ops.aten.linspace)
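
As a rough illustration of the renamed arange signature (a simplified sketch with a made-up helper name, not the actual ref), a single positional argument is treated as the end, exactly as the start/end swap above does:

import torch

def arange_sketch(start=0, end=None, step=1, *, dtype=None, device=None):
    # Mirrors the renamed args: torch.arange(5) means start=0, end=5.
    if end is None:
        start, end = 0, start
    return torch.arange(start, end, step, dtype=dtype, device=device)

assert torch.equal(arange_sketch(5), torch.arange(5))
assert torch.equal(arange_sketch(2, 8, 3), torch.arange(2, 8, 3))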

View File

@ -1,3 +1,4 @@
import math
from typing import Optional
import torch
@ -20,6 +21,7 @@ __all__ = [
"i1",
"i1e",
"logit",
"multigammaln",
"zeta",
]
@ -60,6 +62,18 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
return torch.log(torch.true_divide(self, torch.sub(1, self)))
@register_decomposition(torch.ops.aten.mvlgamma)
@out_wrapper()
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
c = 0.25 * p * (p - 1) * math.log(math.pi)
b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
zeta = _make_elementwise_binary_reference(
prims.zeta, # type: ignore[has-type]
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
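
A quick sanity check of the multigammaln reference above against the eager op (a sketch re-deriving the same identity; not part of the diff):

import math
import torch

def multigammaln_sketch(a: torch.Tensor, p: int) -> torch.Tensor:
    # log Gamma_p(a) = (p*(p-1)/4) * log(pi) + sum_{i=1..p} lgamma(a - (i-1)/2)
    c = 0.25 * p * (p - 1) * math.log(math.pi)
    b = 0.5 * torch.arange(1 - p, 1, dtype=a.dtype, device=a.device)
    return torch.lgamma(a.unsqueeze(-1) + b).sum(dim=-1) + c

a = torch.tensor([2.0, 3.5, 7.0])
assert torch.allclose(multigammaln_sketch(a, 3), torch.mvlgamma(a, 3))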

View File

@ -761,9 +761,9 @@ Computes the `multivariate log-gamma function
.. math::
\log(\Gamma_{p}(a)) = C + \displaystyle \sum_{i=1}^{p} \log\left(\Gamma\left(a - \frac{i - 1}{2}\right)\right)
where :math:`C = \log(\pi) \times \frac{p (p - 1)}{4}` and :math:`\Gamma(\cdot)` is the Gamma function.
where :math:`C = \log(\pi) \cdot \frac{p (p - 1)}{4}` and :math:`\Gamma(\cdot)` is the Gamma function.
All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise an error would be thrown.
All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise the behavior is undefined.
""" + """
Args:

View File

@ -740,7 +740,10 @@ def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs):
for start, end, step in samples:
if start is None:
assert step is None
# Pass end as positional arg
yield SampleInput(end, kwargs={"dtype": dtype, "device": device})
# (Similar to) calling torch.arange(end=3)
yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device})
elif step is None:
yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device})
else:
@ -5670,20 +5673,15 @@ def skips_mvlgamma(skip_redundant=False):
# To test reference numerics against multiple values of argument `p`,
# we make multiple OpInfo entries with each entry corresponding to different value of p.
# We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing.
# Class `MvlGammaInfo` already contains the basic information related to the operator,
# it only takes arguments like `domain`, `skips` and `sample_kwargs`, which
# differ between the entries.
class MvlGammaInfo(UnaryUfuncInfo):
def __init__(self, variant_test_name, domain, skips, sample_kwargs):
super(MvlGammaInfo, self).__init__(
'mvlgamma',
def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs):
return UnaryUfuncInfo('mvlgamma',
ref=reference_mvlgamma if TEST_SCIPY else None,
aliases=('special.multigammaln',),
variant_test_name=variant_test_name,
domain=domain,
decorators=(precisionOverride({torch.float16: 5e-2}),),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.half),
dtypesIfCUDA=all_types_and(torch.float16),
sample_inputs_func=sample_inputs_mvlgamma,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
@ -12133,18 +12131,19 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
),
sample_inputs_func=sample_inputs_mode,),
MvlGammaInfo(variant_test_name='mvlgamma_p_1',
make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1',
domain=(1, None),
skips=skips_mvlgamma() + \
(DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
skips=skips_mvlgamma() + (
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
dtypes=(torch.float16, torch.int8)),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
dtypes=(torch.int8,)),),
dtypes=(torch.int8,)),
),
sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})),
MvlGammaInfo(variant_test_name='mvlgamma_p_3',
make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3',
domain=(2, None),
skips=skips_mvlgamma(skip_redundant=True) + (
skips=skips_mvlgamma() + (
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
dtypes=(torch.float16, torch.int8)),
@ -12152,9 +12151,9 @@ op_db: List[OpInfo] = [
dtypes=(torch.int8,)),
),
sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})),
MvlGammaInfo(variant_test_name='mvlgamma_p_5',
make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5',
domain=(3, None),
skips=skips_mvlgamma(skip_redundant=True) + (
skips=skips_mvlgamma() + (
DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
dtypes=(torch.float16, torch.int8)),
@ -16242,9 +16241,6 @@ python_ref_db = [
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
# See https://github.com/pytorch/pytorch/issues/82364
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
# Prims arange does not follow aten
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',
@ -16486,6 +16482,24 @@ python_ref_db = [
"_refs.lgamma",
torch_opinfo_name="lgamma",
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.multigammaln",
torch_opinfo_name="mvlgamma",
torch_opinfo_variant_name="mvlgamma_p_1",
supports_nvfuser=False,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.multigammaln",
torch_opinfo_name="mvlgamma",
torch_opinfo_variant_name="mvlgamma_p_3",
supports_nvfuser=False,
),
ElementwiseUnaryPythonRefInfo(
"_refs.special.multigammaln",
torch_opinfo_name="mvlgamma",
torch_opinfo_variant_name="mvlgamma_p_5",
supports_nvfuser=False,
),
ElementwiseUnaryPythonRefInfo(
"_refs.log",
torch_opinfo_name="log",