Added logsumexp decomposition (#77219)

Pretty simple.

cc: @jansel, who mentioned this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77219
Approved by: https://github.com/jansel
Horace He 2022-05-12 02:01:31 +00:00 committed by PyTorch MergeBot
parent f6eb811786
commit c25bdeea26
3 changed files with 34 additions and 8 deletions

View File

@@ -99,11 +99,11 @@ native_dropout_cpu(const Tensor& input, double p, c10::optional<bool> train) {
     double p1m = 1. - p;
     // Check for probability of zero to avoid divide by zero and NaN results
     double scale = p1m == 0 ? 0. : 1. / p1m;
-    mask = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    mask = at::empty_like(input, input.options().dtype(c10::CppTypeToScalarType<bool>::value));
     mask.bernoulli_(p1m);
     output = input.mul(mask).mul_(scale);
   } else {
-    mask = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    mask = at::ones_like(input, input.options().dtype(c10::CppTypeToScalarType<bool>::value));
     output = input.clone();
   }
   return std::make_tuple(output, mask);

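Note on the kernel change above: the CPU path now allocates the dropout mask as a boolean tensor instead of reusing the input's dtype, so the eager op matches the bool mask produced by the Python decomposition below. A minimal sketch of the expected behavior, assuming the op is exercised directly through torch.ops.aten on CPU:

import torch

x = torch.randn(4, 4)
# Training path: mask is sampled and is now boolean rather than input-typed.
out, mask = torch.ops.aten.native_dropout(x, 0.5, True)
assert mask.dtype == torch.bool
# Eval path (train=False): output is a copy of the input, mask is all True.
out_eval, mask_eval = torch.ops.aten.native_dropout(x, 0.5, False)
assert out_eval.equal(x) and mask_eval.dtype == torch.bool and bool(mask_eval.all())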
View File

@@ -699,10 +699,13 @@ def logit_backward(
 @register_decomposition(aten.native_dropout)
 @pw_cast_for_opmath
-def native_dropout_decomposition(input: Tensor, p: float, train: Optional[bool]):
-    bool_mask = torch.rand_like(input) < p
-    res = bool_mask * input * float(1.0 / p)
-    return [res, bool_mask]
+def native_dropout(input: Tensor, p: float, train: Optional[bool]):
+    if train:
+        bool_mask = torch.rand_like(input) < p
+        res = bool_mask * input * float(1.0 / p)
+        return (res, bool_mask)
+    else:
+        return (input, torch.ones_like(input, dtype=torch.bool))

 # TODO: Correct the type promotion semantics
@@ -1045,7 +1048,7 @@ def clamp_max(self: Tensor, max: float):
 def _fused_dropout_decomposition(input, p, generator=None):
     mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
     res = mask.type_as(input) * input * (1.0 / p)
-    return [res, mask]
+    return (res, mask)

 # TODO: these logical decomps are buggy for complex inputs
@@ -1243,3 +1246,25 @@ def stack(tensors: List[Tensor], dim: int = 0) -> Tensor:
         return out.view(result_sizes)
     else:
         return torch.cat(get_stack_inputs(tensors, wrapped_dim), dim)
+
+
+def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
+    ndim = self.dim()
+    wrapped_dims = utils.canonicalize_dims(ndim, dims)
+    assert isinstance(wrapped_dims, tuple)
+    for idx in range(ndim - 1, -1, -1):
+        if idx in wrapped_dims:
+            self = self.squeeze(idx)
+    return self
+
+
+@register_decomposition(aten.logsumexp.default)
+@pw_cast_for_int_to_real
+def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
+    if self.numel() == 0:
+        return torch.sum(torch.exp(self), dim, keepdim).log()
+    maxes = torch.amax(self, dim, keepdim=True)
+    maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
+    maxes_squeezed = torch.masked_fill(maxes_squeezed, maxes_squeezed.abs() == float('inf'), 0)
+    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
+    return result.log().add(maxes_squeezed)

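The decomposition above is the standard numerically stable formulation: subtract the per-dimension max before exponentiating, then add it back after the log, with the masked_fill guarding the add-back when that max is itself infinite. A short sketch of why the max subtraction matters, written against public torch ops rather than the decomposition itself:

import torch

x = torch.full((3,), 1000.0)

naive = torch.exp(x).sum().log()                # exp(1000) overflows to inf
m = torch.amax(x)
stable = torch.log(torch.exp(x - m).sum()) + m  # 1000 + log(3)

print(naive)                      # tensor(inf)
print(stable)                     # tensor(1001.0986)
print(torch.logsumexp(x, dim=0))  # agrees with the stable form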
View File

@@ -3126,7 +3126,8 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
     inputs = (
         ((), (0,), True),
         ((S, S), (1,), True),
-        ((S, S), (1,), False)
+        ((S, S), (1,), False),
+        ((S, S), (-2,), False),
     )
     samples = []
     # Test large inputs to check numerical stability
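The added ((S, S), (-2,), False) sample exercises negative-dimension handling in the new decomposition; for a 2-D input, reducing over dim=-2 is the same as reducing over dim=0. For example:

import torch

x = torch.randn(5, 5)
torch.testing.assert_close(torch.logsumexp(x, dim=-2), torch.logsumexp(x, dim=0))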