# mypy: allow-untyped-defs
import math
from typing import Optional

import torch
from torch import Tensor
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property
from torch.types import _size


__all__ = ["LowRankMultivariateNormal"]


def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.
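
    A minimal sanity check (illustrative, with arbitrary shapes) against a dense
    construction of the capacitance matrix::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> W = torch.randn(5, 2)
        >>> D = torch.rand(5) + 0.5
        >>> L = _batch_capacitance_tril(W, D)
        >>> C = torch.eye(2) + W.mT @ torch.diag(D.reciprocal()) @ W
        >>> torch.testing.assert_close(L @ L.mT, C)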
    """
    m = W.size(-1)
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    K.view(-1, m * m)[:, :: m + 1] += 1  # add identity matrix to K
    return torch.linalg.cholesky(K)


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses the "matrix determinant lemma"::

        log|W @ W.T + D| = log|C| + log|D|,

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
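
    A minimal check (illustrative, with arbitrary shapes) against a dense
    log-determinant::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> W = torch.randn(5, 2)
        >>> D = torch.rand(5) + 0.5
        >>> tril = _batch_capacitance_tril(W, D)
        >>> torch.testing.assert_close(
        ...     _batch_lowrank_logdet(W, D, tril),
        ...     torch.logdet(W @ W.mT + torch.diag(D)),
        ... )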
    """
    return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(
        -1
    )


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses the "Woodbury matrix identity"::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
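
    A minimal check (illustrative, with arbitrary shapes) against a dense solve::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> W = torch.randn(5, 2)
        >>> D = torch.rand(5) + 0.5
        >>> x = torch.randn(5)
        >>> tril = _batch_capacitance_tril(W, D)
        >>> torch.testing.assert_close(
        ...     _batch_lowrank_mahalanobis(W, D, x, tril),
        ...     x @ torch.linalg.solve(W @ W.mT + torch.diag(D), x),
        ... )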
    """
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2


class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution whose covariance matrix has a low-rank
    form parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(
        ...     torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2)
        ... )
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        Computing the determinant and inverse of the covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]`, thanks to the `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and the
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        With these formulas, only the determinant and inverse of the small
        "capacitance" matrix need to be computed::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
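
        A minimal consistency check (illustrative, with arbitrary shapes)
        against the dense parameterization::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
            >>> loc = torch.zeros(5)
            >>> W = torch.randn(5, 2)
            >>> d = torch.rand(5) + 0.5
            >>> low_rank = LowRankMultivariateNormal(loc, W, d)
            >>> dense = torch.distributions.MultivariateNormal(
            ...     loc, covariance_matrix=W @ W.mT + torch.diag(d)
            ... )
            >>> x = torch.randn(5)
            >>> torch.testing.assert_close(low_rank.log_prob(x), dense.log_prob(x))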
    """

    # pyrefly: ignore # bad-override
    arg_constraints = {
        "loc": constraints.real_vector,
        "cov_factor": constraints.independent(constraints.real, 2),
        "cov_diag": constraints.independent(constraints.positive, 1),
    }
    support = constraints.real_vector
    has_rsample = True

    def __init__(
        self,
        loc: Tensor,
        cov_factor: Tensor,
        cov_diag: Tensor,
        validate_args: Optional[bool] = None,
    ) -> None:
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError(
                "cov_factor must be at least two-dimensional, "
                "with optional leading batch dimensions"
            )
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError(
                f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m"
            )
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError(
                f"cov_diag must be a batch of vectors with shape {event_shape}"
            )

        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(
                loc_, cov_factor, cov_diag_
            )
        except RuntimeError as e:
            raise ValueError(
                f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}"
            ) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(
            batch_shape, self.event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self) -> Tensor:
        return self.loc

    @property
    def mode(self) -> Tensor:
        return self.loc

    @lazy_property
    def variance(self) -> Tensor:  # type: ignore[override]
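        # diag(W @ W.T) is the row-wise sum of squares of W, so the diagonal of
        # the covariance can be formed without materializing the full matrix.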
        return (
            self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self) -> Tensor:
        # The following identity is used to improve the numerical stability of
        # the Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D^{1/2} @ (I + D^{-1/2} @ W @ W.T @ D^{-1/2}) @ D^{1/2}
        # The matrix "I + D^{-1/2} @ W @ W.T @ D^{-1/2}" has eigenvalues bounded from below by 1,
        # hence it is well-conditioned and safe to take its Cholesky decomposition.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        K.view(-1, n * n)[:, :: n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
        return scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self) -> Tensor:
        covariance_matrix = torch.matmul(
            self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT
        ) + torch.diag_embed(self._unbroadcasted_cov_diag)
        return covariance_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    @lazy_property
    def precision_matrix(self) -> Tensor:
        # We use the Woodbury matrix identity to take advantage of the low-rank form:
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (
            self._unbroadcasted_cov_factor.mT
            / self._unbroadcasted_cov_diag.unsqueeze(-2)
        )
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
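        # A = inv(L_C) @ W.T @ inv(D) for the Cholesky factor L_C of C, so
        # A.mT @ A = inv(D) @ W @ inv(C) @ W.T @ inv(D), the correction term above.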
        precision_matrix = (
            torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        )
        return precision_matrix.expand(
            self._batch_shape + self._event_shape + self._event_shape
        )

    def rsample(self, sample_shape: _size = torch.Size()) -> Tensor:
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
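        # Reparameterized sample: loc + W @ eps_W + sqrt(D) * eps_D has covariance
        # W @ W.T + D because eps_W and eps_D are independent standard normals.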
        return (
            self.loc
            + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
            + self._unbroadcasted_cov_diag.sqrt() * eps_D
        )

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            diff,
            self._capacitance_tril,
        )
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        log_det = _batch_lowrank_logdet(
            self._unbroadcasted_cov_factor,
            self._unbroadcasted_cov_diag,
            self._capacitance_tril,
        )
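        # Differential entropy of a k-dimensional Gaussian:
        #     H = k / 2 * (1 + log(2 * pi)) + log|Sigma| / 2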
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)