from collections import namedtuple
from functools import update_wrapper
from numbers import Number
import math

import torch
import torch.nn.functional as F
from torch.autograd import Variable, variable

# This follows semantics of numpy.finfo.
_Finfo = namedtuple('_Finfo', ['eps', 'tiny'])
_FINFO = {
    torch.HalfStorage: _Finfo(eps=0.00097656, tiny=6.1035e-05),
    torch.FloatStorage: _Finfo(eps=1.19209e-07, tiny=1.17549e-38),
    torch.DoubleStorage: _Finfo(eps=2.22044604925e-16, tiny=2.22507385851e-308),
    torch.cuda.HalfStorage: _Finfo(eps=0.00097656, tiny=6.1035e-05),
    torch.cuda.FloatStorage: _Finfo(eps=1.19209e-07, tiny=1.17549e-38),
    torch.cuda.DoubleStorage: _Finfo(eps=2.22044604925e-16, tiny=2.22507385851e-308),
}


def _finfo(tensor):
    """
    Return floating point info about a `Tensor` or `Variable`:
    - `.eps` is the smallest number that can be added to 1 without being lost.
    - `.tiny` is the smallest positive normal number (much smaller than `.eps`).

    Args:
        tensor (Tensor or Variable): tensor or variable of floating point data.
    Returns:
        _Finfo: a `namedtuple` with fields `.eps` and `.tiny`.
    Raises:
        KeyError: if the storage type of `tensor` has no entry in `_FINFO`
            (only half, float and double storages are listed).
    """
    return _FINFO[tensor.storage_type()]


def expand_n(v, n):
    r"""
    Cleanly expand float or Tensor or Variable parameters.

    A new leading dimension of size `n` is prepended: a number becomes a
    tensor of size `(n, 1)`, a tensor of size `S` becomes one of size
    `(n,) + S`.

    Args:
        v (Number, Tensor or Variable): the value to expand.
        n (int): size of the new leading dimension.
    """
    if isinstance(v, Number):
        return torch.Tensor([v]).expand(n, 1)
    else:
        return v.expand(n, *v.size())


def _broadcast_shape(shapes):
    """
    Given a list of tensor sizes, returns the size of the resulting broadcasted
    tensor.

    Args:
        shapes (list of torch.Size): list of tensor sizes
    """
    # Start from size (1,) so that an empty list still yields a valid size.
    shape = torch.Size([1])
    for s in shapes:
        shape = torch._C._infer_size(s, shape)
    return shape


def broadcast_all(*values):
    """
    Given a list of values (possibly containing numbers), returns a list where each
    value is broadcasted based on the following rules:
      - `torch.Tensor` and `torch.autograd.Variable` instances are broadcasted as
        per the broadcasting rules
        (https://pytorch.org/docs/stable/notes/broadcasting.html)
      - numbers.Number instances (scalars) are upcast to Variables having
        the same size and type as the first tensor passed to `values`. If all the
        values are scalars, then they are upcasted to Variables having size
        `(1,)`.

    Args:
        values (list of `numbers.Number`, `torch.autograd.Variable` or
        `torch.Tensor`)

    Returns:
        list: the broadcasted values, in the same order as `values`.

    Raises:
        ValueError: if any of the values is not a `numbers.Number`, `torch.Tensor`
            or `torch.autograd.Variable` instance, or if numbers are mixed with
            plain `torch.Tensor` arguments.
    """
    values = list(values)
    scalar_idxs = [i for i, v in enumerate(values) if isinstance(v, Number)]
    tensor_idxs = [i for i, v in enumerate(values)
                   if torch.is_tensor(v) or isinstance(v, Variable)]
    if len(scalar_idxs) + len(tensor_idxs) != len(values):
        raise ValueError('Input arguments must all be instances of numbers.Number, torch.Tensor or ' +
                         'torch.autograd.Variable.')
    if tensor_idxs:
        broadcast_shape = _broadcast_shape([values[i].size() for i in tensor_idxs])
        for idx in tensor_idxs:
            values[idx] = values[idx].expand(broadcast_shape)
        # Scalars are materialised with the type and size of the first tensor.
        template = values[tensor_idxs[0]]
        if scalar_idxs and not isinstance(template, torch.autograd.Variable):
            raise ValueError(('Input arguments containing instances of numbers.Number and torch.Tensor '
                              'are not currently supported.  Use torch.autograd.Variable instead of torch.Tensor'))
        for idx in scalar_idxs:
            values[idx] = template.new(template.size()).fill_(values[idx])
    else:
        # NOTE(review): relies on `torch.autograd.variable` being a callable
        # factory -- confirm against the torch version this file targets.
        for idx in scalar_idxs:
            values[idx] = variable(values[idx])
    return values


def softmax(tensor):
    """
    Wrapper around softmax to make it work with both Tensors and Variables.
    TODO: Remove once https://github.com/pytorch/pytorch/issues/2633 is resolved.
    """
    if not isinstance(tensor, Variable):
        return F.softmax(Variable(tensor), -1).data
    return F.softmax(tensor, -1)


def log_sum_exp(tensor, keepdim=True):
    """
    Numerically stable implementation for the `LogSumExp` operation. The
    summing is done along the last dimension.

    Args:
        tensor (torch.Tensor or torch.autograd.Variable)
        keepdim (Boolean): Whether to retain the last dimension on summing.

    Returns:
        The log-sum-exp over the last dimension. That dimension has size 1
        when `keepdim` is true and is removed otherwise.
    """
    max_val = tensor.max(dim=-1, keepdim=True)[0]
    # Reduce with keepdim=True so both terms have the same shape; dropping the
    # dimension on only one of them would broadcast (n, 1) + (n,) to (n, n).
    result = max_val + (tensor - max_val).exp().sum(dim=-1, keepdim=True).log()
    if keepdim:
        return result
    return result.squeeze(-1)


def logits_to_probs(logits, is_binary=False):
    """
    Converts a tensor of logits into probabilities. Note that for the
    binary case, each value denotes log odds, whereas for the
    multi-dimensional case, the values along the last dimension denote
    the log probabilities (possibly unnormalized) of the events.
    """
    if is_binary:
        return F.sigmoid(logits)
    return softmax(logits)


def clamp_probs(probs):
    """
    Clamps probabilities to the closed interval `[eps, 1 - eps]`, where `eps`
    is the machine epsilon of the floating point type of `probs`.
    """
    eps = _finfo(probs).eps
    return probs.clamp(min=eps, max=1 - eps)


def probs_to_logits(probs, is_binary=False):
    """
    Converts a tensor of probabilities into logits. For the binary case,
    this denotes the probability of occurrence of the event indexed by `1`.
    For the multi-dimensional case, the values along the last dimension
    denote the probabilities of occurrence of each of the events.
    """
    # Clamp away from 0 and 1 so the logarithms below stay finite.
    ps_clamped = clamp_probs(probs)
    if is_binary:
        return torch.log(ps_clamped) - torch.log1p(-ps_clamped)
    return torch.log(ps_clamped)


class lazy_property(object):
    """
    Used as a decorator for lazy loading of class attributes. This uses a
    non-data descriptor that calls the wrapped method to compute the property on
    first call; thereafter replacing the wrapped method into an instance
    attribute.
    """
    def __init__(self, wrapped):
        self.wrapped = wrapped
        update_wrapper(self, wrapped)

    def __get__(self, instance, obj_type=None):
        # Accessed on the class itself: hand back the descriptor.
        if instance is None:
            return self
        value = self.wrapped(instance)
        # The instance attribute shadows this non-data descriptor from now on.
        setattr(instance, self.wrapped.__name__, value)
        return value