fix gumbel_softmax

ghstack-source-id: 9fd5822c6a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29674
Author: Vitaly Fedyunin
Date:   2019-11-18 08:27:59 -08:00
Parent: e8de4828ef
Commit: 0fd5ee87fe


@@ -1279,14 +1279,14 @@ def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
     if eps != 1e-10:
         warnings.warn("`eps` parameter is deprecated and has no effect.")
 
-    gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()  # ~Gumbel(0,1)
+    gumbels = -torch.empty_like(logits, memory_format=torch.contiguous_format).exponential_().log()  # ~Gumbel(0,1)
     gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
     y_soft = gumbels.softmax(dim)
 
     if hard:
         # Straight through.
         index = y_soft.max(dim, keepdim=True)[1]
-        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+        y_hard = torch.zeros_like(logits, memory_format=torch.contiguous_format).scatter_(dim, index, 1.0)
         ret = y_hard - y_soft.detach() + y_soft
     else:
         # Reparametrization trick.
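
For context, and not part of the commit: the sketch below (assuming a recent PyTorch; tensor shapes and names are made up for illustration) shows how F.gumbel_softmax is used with the straight-through hard=True path touched by this hunk, and what the memory_format argument passed to empty_like / zeros_like controls for the freshly allocated noise and one-hot buffers.

# Illustrative usage sketch, not part of the diff.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)  # unnormalized log-probabilities

# Soft samples: a relaxed categorical distribution, differentiable in `logits`.
y_soft = F.gumbel_softmax(logits, tau=1.0, hard=False)

# Hard samples: one-hot in the forward pass; gradients still flow through the
# soft relaxation via the straight-through trick implemented in the hunk above.
y_hard = F.gumbel_softmax(logits, tau=1.0, hard=True)
assert torch.allclose(y_hard.sum(dim=-1), torch.ones(4))  # each row is one-hot
y_hard.sum().backward()                                   # gradients reach `logits`

# The memory_format argument controls the layout of freshly allocated buffers:
# with an explicit contiguous format, empty_like returns a row-major tensor
# even when the input tensor uses channels_last strides.
x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
buf = torch.empty_like(x, memory_format=torch.contiguous_format)
assert buf.is_contiguous()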