mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE] Use torch.log1p(x) instead of torch.log(1+x) (#141167)
To fix TOR107 linter violations Found while trying to migrate PyTorch to latest torchfix Pull Request resolved: https://github.com/pytorch/pytorch/pull/141167 Approved by: https://github.com/kit1980, https://github.com/Skylion007
This commit is contained in:
parent
cd942d00dd
commit
2d52f7946b
|
|
@ -1275,22 +1275,22 @@ def module_inputs_torch_nn_Softplus(module_info, device, dtype, requires_grad, t
|
|||
return [
|
||||
ModuleInput(constructor_input=FunctionInput(),
|
||||
forward_input=FunctionInput(make_input((10, 20))),
|
||||
reference_fn=lambda m, p, i: torch.log(1 + torch.exp(i))),
|
||||
reference_fn=lambda m, p, i: torch.log1p(torch.exp(i))),
|
||||
ModuleInput(constructor_input=FunctionInput(2),
|
||||
forward_input=FunctionInput(make_input((10, 20))),
|
||||
reference_fn=lambda m, p, i: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
|
||||
reference_fn=lambda m, p, i: 1. / 2. * torch.log1p(torch.exp(2 * i)),
|
||||
desc='beta'),
|
||||
ModuleInput(constructor_input=FunctionInput(2, -100),
|
||||
forward_input=FunctionInput(make_input((10, 20))),
|
||||
reference_fn=(
|
||||
lambda m, p, i: ((i * 2) > -100).type_as(i) * i
|
||||
+ ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
|
||||
+ ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
|
||||
desc='beta_threshold'),
|
||||
ModuleInput(constructor_input=FunctionInput(2, -100),
|
||||
forward_input=FunctionInput(make_input(())),
|
||||
reference_fn=(
|
||||
lambda m, p, i: ((i * 2) > -100).type_as(i) * i
|
||||
+ ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
|
||||
+ ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log1p(torch.exp(2 * i))),
|
||||
desc='beta_threshold_scalar'),
|
||||
ModuleInput(constructor_input=FunctionInput(),
|
||||
forward_input=FunctionInput(make_input(4)),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user