mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE] Use torch.special.expm1 (#141518)
Instead of `torch.exp(x)-1`, as suggested by TorchFix Pull Request resolved: https://github.com/pytorch/pytorch/pull/141518 Approved by: https://github.com/kit1980
This commit is contained in:
parent
dcd16bdc21
commit
f2d388eddd
|
|
@ -231,8 +231,8 @@ class ContinuousBernoulli(ExponentialFamily):
|
|||
cut_nat_params = torch.where(
|
||||
out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x)
|
||||
)
|
||||
log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(
|
||||
torch.abs(cut_nat_params)
|
||||
)
|
||||
log_norm = torch.log(
|
||||
torch.abs(torch.special.expm1(cut_nat_params))
|
||||
) - torch.log(torch.abs(cut_nat_params))
|
||||
taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
|
||||
return torch.where(out_unst_reg, log_norm, taylor)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user