Added logsumexp decomposition (#77219)
Pretty simple. cc: @jansel who mentioned this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77219
Approved by: https://github.com/jansel
parent f6eb811786
commit c25bdeea26
@@ -99,11 +99,11 @@ native_dropout_cpu(const Tensor& input, double p, c10::optional<bool> train) {
     double p1m = 1. - p;
     // Check for probability of zero to avoid divide by zero and NaN results
     double scale = p1m == 0 ? 0. : 1. / p1m;
-    mask = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    mask = at::empty_like(input, input.options().dtype(c10::CppTypeToScalarType<bool>::value));
     mask.bernoulli_(p1m);
     output = input.mul(mask).mul_(scale);
   } else {
-    mask = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    mask = at::ones_like(input, input.options().dtype(c10::CppTypeToScalarType<bool>::value));
     output = input.clone();
   }
   return std::make_tuple(output, mask);
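The net effect of this hunk is that the mask returned by aten::native_dropout on CPU is now allocated as a bool tensor rather than with LEGACY_CONTIGUOUS_MEMORY_FORMAT. A minimal sketch of how that shows up from Python, assuming the op is invoked directly through torch.ops.aten (illustrative only, not part of this PR):

    import torch

    x = torch.randn(4, 4)
    # native_dropout returns (output, mask); with this change the CPU mask is bool.
    out, mask = torch.ops.aten.native_dropout(x, 0.5, True)
    print(mask.dtype)  # torch.bool on a build that includes this change
    print(out.dtype)   # matches x.dtype; kept elements are scaled by 1 / (1 - p)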
@@ -699,10 +699,13 @@ def logit_backward(
 
 @register_decomposition(aten.native_dropout)
 @pw_cast_for_opmath
-def native_dropout_decomposition(input: Tensor, p: float, train: Optional[bool]):
-    bool_mask = torch.rand_like(input) < p
-    res = bool_mask * input * float(1.0 / p)
-    return [res, bool_mask]
+def native_dropout(input: Tensor, p: float, train: Optional[bool]):
+    if train:
+        bool_mask = torch.rand_like(input) < p
+        res = bool_mask * input * float(1.0 / p)
+        return (res, bool_mask)
+    else:
+        return (input, torch.ones_like(input, dtype=torch.bool))
 
 
 # TODO: Correct the type promotion semantics
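The decomposition above now covers eval mode as well: when train is falsy it is the identity and returns an all-True bool mask. A self-contained sketch of that behaviour (the function name here is illustrative, not the PyTorch API):

    import torch
    from typing import Optional, Tuple

    def native_dropout_sketch(x: torch.Tensor, p: float,
                              train: Optional[bool]) -> Tuple[torch.Tensor, torch.Tensor]:
        # Same formula as the decomposition in the diff above.
        if train:
            bool_mask = torch.rand_like(x) < p
            return (bool_mask * x * float(1.0 / p), bool_mask)
        return (x, torch.ones_like(x, dtype=torch.bool))

    x = torch.randn(2, 3)
    out, mask = native_dropout_sketch(x, 0.5, train=False)
    assert torch.equal(out, x) and bool(mask.all()) and mask.dtype == torch.bool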
@@ -1045,7 +1048,7 @@ def clamp_max(self: Tensor, max: float):
 def _fused_dropout_decomposition(input, p, generator=None):
     mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
     res = mask.type_as(input) * input * (1.0 / p)
-    return [res, mask]
+    return (res, mask)
 
 
 # TODO: these logical decomps are buggy for complex inputs
@@ -1243,3 +1246,25 @@ def stack(tensors: List[Tensor], dim: int = 0) -> Tensor:
         return out.view(result_sizes)
     else:
         return torch.cat(get_stack_inputs(tensors, wrapped_dim), dim)
+
+
+def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
+    ndim = self.dim()
+    wrapped_dims = utils.canonicalize_dims(ndim, dims)
+    assert isinstance(wrapped_dims, tuple)
+    for idx in range(ndim - 1, -1, -1):
+        if idx in wrapped_dims:
+            self = self.squeeze(idx)
+    return self
+
+
+@register_decomposition(aten.logsumexp.default)
+@pw_cast_for_int_to_real
+def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
+    if self.numel() == 0:
+        return torch.sum(torch.exp(self), dim, keepdim).log()
+    maxes = torch.amax(self, dim, keepdim=True)
+    maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
+    maxes_squeezed = torch.masked_fill(maxes_squeezed, maxes_squeezed.abs() == float('inf'), 0)
+    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
+    return result.log().add(maxes_squeezed)
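The new logsumexp decomposition uses the usual max-subtraction trick for numerical stability; _squeeze_multiple removes the reduced dimensions from the max tensor when keepdim is False. An illustrative check (not part of the PR) that the stabilised formula agrees with the reference torch.logsumexp on finite inputs:

    import torch

    x = torch.randn(5, 7) * 100  # large magnitudes to stress numerical stability
    dim = (1,)

    # Subtract the per-slice max, exponentiate, sum, take the log, add the max back.
    maxes = torch.amax(x, dim, keepdim=True)
    stabilised = torch.sum(torch.exp(x - maxes), dim, keepdim=True).log() + maxes

    torch.testing.assert_close(stabilised, torch.logsumexp(x, dim, keepdim=True))

    # The decomposition additionally handles empty tensors and zeroes out infinite
    # maxes (the masked_fill line) before adding them back; this check only covers
    # finite inputs with keepdim=True.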
@@ -3126,7 +3126,8 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad, **kwargs):
     inputs = (
         ((), (0,), True),
         ((S, S), (1,), True),
-        ((S, S), (1,), False)
+        ((S, S), (1,), False),
+        ((S, S), (-2,), False),
     )
     samples = []
     # Test large inputs to check numerical stability
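The added OpInfo sample exercises a negative reduction dim. Roughly the direct call it corresponds to, assuming S is the test suite's small-size constant (e.g. 5):

    import torch

    S = 5
    x = torch.randn(S, S)
    print(torch.logsumexp(x, dim=-2, keepdim=False).shape)  # torch.Size([5])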