[BE] Use torch.special.expm1 (#141518)

Instead of `torch.exp(x)-1`, as suggested by TorchFix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141518
Approved by: https://github.com/kit1980
This commit is contained in:
Nikita Shulga 2024-11-26 01:47:09 +00:00 committed by PyTorch MergeBot
parent dcd16bdc21
commit f2d388eddd

View File

@ -231,8 +231,8 @@ class ContinuousBernoulli(ExponentialFamily):
cut_nat_params = torch.where(
out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x)
)
log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(
torch.abs(cut_nat_params)
)
log_norm = torch.log(
torch.abs(torch.special.expm1(cut_nat_params))
) - torch.log(torch.abs(cut_nat_params))
taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
return torch.where(out_unst_reg, log_norm, taylor)