Adding multigammaln ref and fix arange (#85153)
Partially based on https://github.com/pytorch/pytorch/pull/83662. I'll help land this one, as Rob no longer works on the PyTorch project.

I removed the data-dependent check on the args, since data dependencies are problematic for many reasons (and the check was failing when the input has NaNs). This PR also registers arange as a decomposition and fixes the naming of its args.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85153
Approved by: https://github.com/mruberry, https://github.com/ngimel
This commit is contained in:
parent
7a6c4d0c50
commit
d17b144e65
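The data-dependent check referred to above has the shape used in `mvlgamma_check` in the first hunk below. A minimal sketch (illustrative values, not from the PR) of why such a check misbehaves on NaN inputs and on meta/fake tensors:

import torch

# Comparisons against NaN are False, so an .all().item()-style argument check
# rejects inputs that should simply propagate NaN through the op. The .item()
# call also reads concrete values, which meta/fake tensors do not have.
p = 3
x = torch.tensor([2.0, float("nan")])
passes = (x > 0.5 * (p - 1)).all().item()
print(passes)  # False, solely because of the NaN element
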
@@ -723,8 +723,7 @@ constexpr double QUARTER = 0.25;
}

static inline void mvlgamma_check(const Tensor& self, int64_t p) {
  TORCH_CHECK((self > HALF * (p - 1)).all().item<bool>(),
              "All elements must be greater than (p-1)/2");
  TORCH_CHECK(self.scalar_type() != kBool, "The input tensor may not be a boolean tensor.");
  TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1");
}

@@ -423,7 +423,6 @@ meta_function_expected_failures = {
    torch.median : {f64, i32, i64, u8, i16, bf16, i8, f32},
    torch.mode : {f64, i32, i64, f16, u8, i16, bf16, b8, i8, f32},
    torch.multinomial : {f64, bf16, f32},
    torch.mvlgamma : {f64, i32, i64, u8, i16, bf16, i8, f32},
    torch.nn.functional.ctc_loss : {f64, f32},
    torch.nn.functional.gaussian_nll_loss : {f64, bf16, f32},
    torch.nn.functional.max_pool3d : {f64, f32},

@@ -543,7 +542,6 @@ meta_function_device_expected_failures['cuda'] = {
    torch.matrix_exp: {f16},  # aten::linalg_matrix_exp
    torch.median: {f16},  # aten::median, aten::median.dim_values
    torch.multinomial: {f16},  # aten::multinomial, aten::multinomial.out
    torch.mvlgamma: {f16},  # aten::_local_scalar_dense, aten::mvlgamma.out
    torch.nn.functional.gaussian_nll_loss: {f16},  # aten::_local_scalar_dense
    torch.nn.functional.max_pool3d: {bf16, f16},  # aten::max_pool3d_with_indices
    torch.nn.functional.max_pool3d_with_indices: {bf16, f16},  # aten::max_pool3d_with_indices

@@ -687,8 +685,6 @@ meta_dispatch_expected_failures = {
    aten.multilabel_margin_loss_forward.default : {f32, f64},
    aten.multinomial.default : {bf16, f32, f64},
    aten.multinomial.out : {bf16, f32, f64},
    aten.mvlgamma.default : {i8, f64, i64, bf16, f32, i32, i16, u8},
    aten.mvlgamma.out : {i8, f64, i64, bf16, f32, i32, i16, u8},
    aten.nll_loss2d_forward.default : {bf16, f32, f64},
    aten.polar.default : {f32, f64},
    aten.rrelu_with_noise.default : {bf16, f32, f64},

@@ -745,8 +741,6 @@ meta_dispatch_device_expected_failures['cuda'] = {
    aten.multilabel_margin_loss_forward.default: {bf16, f16},  # aten::multilabel_margin_loss_forward
    aten.multinomial.default: {f16},  # aten::multinomial
    aten.multinomial.out: {f16},  # aten::multinomial.out
    aten.mvlgamma.default: {f16},  # aten::_local_scalar_dense
    aten.mvlgamma.out: {f16},  # aten::mvlgamma.out
    aten.native_group_norm.default: {bf16, f16},
    aten.nll_loss2d_forward.default: {f16},  # aten::nll_loss2d_forward
    aten.ormqr.default: {f32, f64},  # aten::ormqr

@@ -997,9 +997,6 @@ fake_tensor_failures = {
    # FakeTensor fallback doesn't work
    xfail('segment_reduce', 'lengths'),
    xfail('multinomial'),
    xfail('mvlgamma', 'mvlgamma_p_1'),
    xfail('mvlgamma', 'mvlgamma_p_3'),
    xfail('mvlgamma', 'mvlgamma_p_5'),
    xfail('cholesky'),
    xfail('cholesky_inverse'),
    # ASAN failures due to divide by 0

@@ -624,16 +624,6 @@ class TestUnaryUfuncs(TestCase):
        ):
            torch.frexp(input, out=(mantissa, exponent))

    def test_mvlgamma_argcheck(self, device):
        def run_test(d):
            input = torch.linspace((d - 2) / 2, 10, 10, device=device)
            torch.mvlgamma(input, d)

        with self.assertRaisesRegex(
            RuntimeError, r"All elements must be greater than \(p-1\)/2"
        ):
            run_test(3)

    def test_polygamma_neg(self, device):
        with self.assertRaisesRegex(
            RuntimeError, r"polygamma\(n, x\) does not support negative n\."

@@ -1579,7 +1579,7 @@ def index_add_(
        utils.is_weakly_lesser_type(type(alpha), python_type),
        lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
    )
    tensor = torch._prims.mul(tensor, alpha)
    tensor = tensor * alpha
    idx = (slice(None),) * dim + (index,)
    torch.ops.aten.index_put_(x, idx, tensor, accumulate=True)
    return x

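For context, the hunk above is inside an index_add_ reference that scales the source by alpha and then accumulates it into the indexed slice. A small sketch of that equivalence for dim=0 (the tensors here are illustrative, not taken from the PR):

import torch

x = torch.zeros(3, 4)
index = torch.tensor([0, 2, 0])
src = torch.arange(12, dtype=torch.float32).reshape(3, 4)
alpha = 2.0

# Built-in index_add_ along dim 0 (duplicate indices accumulate).
expected = torch.zeros(3, 4).index_add_(0, index, src, alpha=alpha)

# The decomposition's idea: scale first, then do an accumulating index_put_.
manual = x.clone()
manual.index_put_((index,), src * alpha, accumulate=True)

assert torch.allclose(expected, manual)
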
@@ -610,6 +610,10 @@ def lgamma(a):
    return prims.lgamma(a)


# alias
mvlgamma = torch.special.multigammaln  # type: ignore[has-type]


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log(a):
    return prims.log(a)

@@ -3707,37 +3711,6 @@ def empty_like(
    )


@overload
def arange(
    end: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    pass


@overload
def arange(
    start: NumberType,
    end: NumberType,
    step: NumberType = 1,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    pass


# See https://github.com/pytorch/pytorch/issues/82364
# @register_decomposition(torch.ops.aten.arange)
# @out_wrapper()
@register_decomposition(
    [
        torch.ops.aten.arange.default,

@@ -3745,9 +3718,10 @@ def arange(
        torch.ops.aten.arange.start_step,
    ]
)
@out_wrapper()
def arange(
    a: Optional[NumberType] = None,
    b: Optional[NumberType] = None,
    start: NumberType = 0,
    end: Optional[NumberType] = None,
    step: NumberType = 1,
    *,
    dtype: Optional[torch.dtype] = None,

@@ -3756,31 +3730,22 @@ def arange(
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    assert (a is not None and b is not None) or (a is not None and b is None)
    if a is not None and b is not None:
        return prims.arange(
            a,
            b,
            step,
            dtype=dtype,
            device=device,
            # layout=layout,
            # pin_memory=pin_memory,
            requires_grad=requires_grad,
        )
    elif a is not None and b is None:
        return prims.arange(
            0,
            a,
            step,
            dtype=dtype,
            device=device,
            # layout=layout,
            # pin_memory=pin_memory,
            requires_grad=requires_grad,
        )
    else:
        raise AssertionError()
    assert not pin_memory
    assert layout == torch.strided
    # Case: torch.arange(5)
    if end is None:
        end = start
        start = 0
    return prims.arange(
        start,
        end,
        step,
        dtype=dtype,
        device=device,
        # layout=layout,
        # pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_decomposition(torch.ops.aten.linspace)

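A short illustration of the calling convention the rewritten reference mirrors: a single positional value is treated as end with an implicit start of 0, matching torch.arange (plain torch.arange is used here as a stand-in for the ref):

import torch

# One positional argument is interpreted as `end`; start defaults to 0 and step to 1.
assert torch.equal(torch.arange(5), torch.arange(0, 5, 1))

# The start/end/step form, as handled by the aten::arange.start_step overload.
assert torch.equal(torch.arange(2, 10, 3), torch.tensor([2, 5, 8]))
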
@@ -1,3 +1,4 @@
import math
from typing import Optional

import torch

@@ -20,6 +21,7 @@ __all__ = [
    "i1",
    "i1e",
    "logit",
    "multigammaln",
    "zeta",
]

@@ -60,6 +62,18 @@ def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
    return torch.log(torch.true_divide(self, torch.sub(1, self)))


@register_decomposition(torch.ops.aten.mvlgamma)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
    c = 0.25 * p * (p - 1) * math.log(math.pi)
    b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
    return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c


zeta = _make_elementwise_binary_reference(
    prims.zeta,  # type: ignore[has-type]
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,

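The reference above evaluates log Gamma_p(a) as the constant C = 0.25 * p * (p - 1) * log(pi) plus a sum of lgamma terms at the shifted points a, a - 1/2, ..., a - (p - 1)/2. A quick sanity check of that construction against torch.special.multigammaln (the input values are arbitrary, chosen inside the valid domain a > (p - 1)/2):

import math
import torch

p = 3
a = torch.tensor([1.5, 2.0, 7.25])  # all elements > (p - 1) / 2

# Same construction as the reference: offsets (1 - p)/2, ..., -1/2, 0.
c = 0.25 * p * (p - 1) * math.log(math.pi)
b = 0.5 * torch.arange(start=1 - p, end=1, step=1, dtype=a.dtype, device=a.device)
ref = torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c

assert torch.allclose(ref, torch.special.multigammaln(a, p))
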
@@ -761,9 +761,9 @@ Computes the `multivariate log-gamma function
.. math::
    \log(\Gamma_{p}(a)) = C + \displaystyle \sum_{i=1}^{p} \log\left(\Gamma\left(a - \frac{i - 1}{2}\right)\right)

where :math:`C = \log(\pi) \times \frac{p (p - 1)}{4}` and :math:`\Gamma(\cdot)` is the Gamma function.
where :math:`C = \log(\pi) \cdot \frac{p (p - 1)}{4}` and :math:`\Gamma(\cdot)` is the Gamma function.

All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise an error would be thrown.
All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise the behavior is undefined.
""" + """

Args:

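As a concrete instance of the docstring formula (a worked example, not part of the documented text), for p = 2 it reduces to:

.. math::
    \log(\Gamma_{2}(a)) = \frac{\log(\pi)}{2} + \log(\Gamma(a)) + \log\left(\Gamma\left(a - \frac{1}{2}\right)\right)
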
@@ -740,7 +740,10 @@ def sample_inputs_arange(op, device, dtype, requires_grad, **kwargs):
    for start, end, step in samples:
        if start is None:
            assert step is None
            # Pass end as positional arg
            yield SampleInput(end, kwargs={"dtype": dtype, "device": device})
            # (Similar to) calling torch.arange(end=3)
            yield SampleInput(0, kwargs={"end": end, "dtype": dtype, "device": device})
        elif step is None:
            yield SampleInput(start, args=(end,), kwargs={"dtype": dtype, "device": device})
        else:

@@ -5670,25 +5673,20 @@ def skips_mvlgamma(skip_redundant=False):
# To test reference numerics against multiple values of argument `p`,
# we make multiple OpInfo entries with each entry corresponding to different value of p.
# We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing.
# Class `MvlGammaInfo` already contains the basic information related to the operator,
# it only takes arguments like `domain`, `skips` and `sample_kwargs`, which
# differ between the entries.
class MvlGammaInfo(UnaryUfuncInfo):
    def __init__(self, variant_test_name, domain, skips, sample_kwargs):
        super(MvlGammaInfo, self).__init__(
            'mvlgamma',
            ref=reference_mvlgamma if TEST_SCIPY else None,
            aliases=('special.multigammaln',),
            variant_test_name=variant_test_name,
            domain=domain,
            decorators=(precisionOverride({torch.float16: 5e-2}),),
            dtypes=all_types_and(torch.bfloat16),
            dtypesIfCUDA=all_types_and(torch.half),
            sample_inputs_func=sample_inputs_mvlgamma,
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=skips,
            sample_kwargs=sample_kwargs)
def make_mvlgamma_opinfo(variant_test_name, domain, skips, sample_kwargs):
    return UnaryUfuncInfo('mvlgamma',
                          ref=reference_mvlgamma if TEST_SCIPY else None,
                          aliases=('special.multigammaln',),
                          variant_test_name=variant_test_name,
                          domain=domain,
                          decorators=(precisionOverride({torch.float16: 5e-2}),),
                          dtypes=all_types_and(torch.bfloat16),
                          dtypesIfCUDA=all_types_and(torch.float16),
                          sample_inputs_func=sample_inputs_mvlgamma,
                          supports_forward_ad=True,
                          supports_fwgrad_bwgrad=True,
                          skips=skips,
                          sample_kwargs=sample_kwargs)


def sample_inputs_cumulative_ops(op_info, device, dtype, requires_grad, supports_dtype_kwargs=True, **kwargs):

@@ -12133,35 +12131,36 @@ op_db: List[OpInfo] = [
               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
           ),
           sample_inputs_func=sample_inputs_mode,),
    MvlGammaInfo(variant_test_name='mvlgamma_p_1',
                 domain=(1, None),
                 skips=skips_mvlgamma() + \
                     (DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
                      DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                   dtypes=(torch.float16, torch.int8)),
                      DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                   dtypes=(torch.int8,)),),
                 sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})),
    MvlGammaInfo(variant_test_name='mvlgamma_p_3',
                 domain=(2, None),
                 skips=skips_mvlgamma(skip_redundant=True) + (
                     DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
                     DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                  dtypes=(torch.float16, torch.int8)),
                     DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                  dtypes=(torch.int8,)),
                 ),
                 sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})),
    MvlGammaInfo(variant_test_name='mvlgamma_p_5',
                 domain=(3, None),
                 skips=skips_mvlgamma(skip_redundant=True) + (
                     DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
                     DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                  dtypes=(torch.float16, torch.int8)),
                     DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                  dtypes=(torch.int8,)),
                 ),
                 sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})),
    make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1',
                         domain=(1, None),
                         skips=skips_mvlgamma() + (
                             DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
                             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                          dtypes=(torch.float16, torch.int8)),
                             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                          dtypes=(torch.int8,)),
                         ),
                         sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})),
    make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3',
                         domain=(2, None),
                         skips=skips_mvlgamma() + (
                             DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
                             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                          dtypes=(torch.float16, torch.int8)),
                             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                          dtypes=(torch.int8,)),
                         ),
                         sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})),
    make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5',
                         domain=(3, None),
                         skips=skips_mvlgamma() + (
                             DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
                             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
                                          dtypes=(torch.float16, torch.int8)),
                             DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
                                          dtypes=(torch.int8,)),
                         ),
                         sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})),
    BinaryUfuncInfo('ne',
                    ref=np.not_equal,
                    aliases=('not_equal',),

@@ -16242,9 +16241,6 @@ python_ref_db = [
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'),
            # See https://github.com/pytorch/pytorch/issues/82364
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),

            # Prims arange does not follow aten
            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta',

@@ -16486,6 +16482,24 @@ python_ref_db = [
        "_refs.lgamma",
        torch_opinfo_name="lgamma",
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.multigammaln",
        torch_opinfo_name="mvlgamma",
        torch_opinfo_variant_name="mvlgamma_p_1",
        supports_nvfuser=False,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.multigammaln",
        torch_opinfo_name="mvlgamma",
        torch_opinfo_variant_name="mvlgamma_p_3",
        supports_nvfuser=False,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.special.multigammaln",
        torch_opinfo_name="mvlgamma",
        torch_opinfo_variant_name="mvlgamma_p_5",
        supports_nvfuser=False,
    ),
    ElementwiseUnaryPythonRefInfo(
        "_refs.log",
        torch_opinfo_name="log",