[BE] Use torch.log1p(x) instead of torch.log(1+x) (#141167)

To fix TOR107 linter violations
Found while trying to migrate PyTorch to latest torchfix
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141167
Approved by: https://github.com/kit1980, https://github.com/Skylion007
This commit is contained in:
Nikita Shulga 2024-11-21 00:36:17 +00:00 committed by PyTorch MergeBot
parent cd942d00dd
commit 2d52f7946b

View File

@@ -1275,22 +1275,22 @@ def module_inputs_torch_nn_Softplus(module_info, device, dtype, requires_grad, t
return [
ModuleInput(constructor_input=FunctionInput(),
forward_input=FunctionInput(make_input((10, 20))),
-reference_fn=lambda m, p, i: torch.log(1 + torch.exp(i))),
+reference_fn=lambda m, p, i: torch.log1p(torch.exp(i))),
ModuleInput(constructor_input=FunctionInput(2),
forward_input=FunctionInput(make_input((10, 20))),
-reference_fn=lambda m, p, i: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
+reference_fn=lambda m, p, i: 1. / 2. * torch.log1p(torch.exp(2 * i)),
desc='beta'),
ModuleInput(constructor_input=FunctionInput(2, -100),
forward_input=FunctionInput(make_input((10, 20))),
reference_fn=(
lambda m, p, i: ((i * 2) > -100).type_as(i) * i
-+ ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
++ ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
desc='beta_threshold'),
ModuleInput(constructor_input=FunctionInput(2, -100),
forward_input=FunctionInput(make_input(())),
reference_fn=(
lambda m, p, i: ((i * 2) > -100).type_as(i) * i
-+ ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
++ ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
desc='beta_threshold_scalar'),
ModuleInput(constructor_input=FunctionInput(),
forward_input=FunctionInput(make_input(4)),