Make torch.lu differentiable for wide/tall inputs + jit (#61564)
Summary: As per title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61564

Reviewed By: astaff

Differential Revision: D30338136

Pulled By: mruberry

fbshipit-source-id: f01436fc90980544cdfa270feee16bb3dda21b93
parent 979180cd01
commit dbcfd7739f
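With this change, the backward of `torch.lu` is wired up at the ATen level (the new `_lu_with_info_backward` entry in derivatives.yaml below), so gradients are available for wide and tall inputs rather than only square ones. A minimal usage sketch, not part of the commit (shapes and seeds are arbitrary):

```python
import torch

# Wide (3 x 5) input: backward through the packed LU factor now works.
A = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
LU, pivots = torch.lu(A)
LU.sum().backward()
print(A.grad.shape)  # torch.Size([3, 5])

# Tall input: gradcheck only the differentiable LU output
# (pivots are integer-valued and non-differentiable).
B = torch.randn(4, 2, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(lambda x: torch.lu(x)[0], (B,)))
```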
@@ -6809,7 +6809,7 @@
   dispatch:
     CPU, CUDA: ormqr

-- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
+- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
   variants: function
   dispatch:
     CPU, CUDA: _lu_with_info
@@ -51,6 +51,7 @@ ALLOW_LIST = [
     ("aten::_svd_helper", datetime.date(2021, 1, 31)),
     ("aten::_syevd_helper", datetime.date(9999, 1, 1)),
     ("aten::_lu_solve_helper", datetime.date(9999, 1, 1)),
+    ("aten::_lu_with_info", datetime.date(9999, 1, 1)),
     ("aten::_linalg_solve_out_helper_", datetime.date(9999, 1, 1)),
     ("aten::_cudnn_rnn_flatten_weight", datetime.date(2020, 12, 31)),
     ("aten::_cudnn_rnn", datetime.date(2020, 12, 31)),
@@ -2865,6 +2865,7 @@ class TestOperatorSignatures(JitTestCase):
            'fill_',
            'hstack',
            'linalg.multi_dot',
+           'lu',
            'norm',
            'polygamma',
            'special.polygamma',
@@ -19,6 +19,7 @@ all_operators_with_namedtuple_return = {
     'frexp', 'lu_unpack', 'histogram', '_fake_quantize_per_tensor_affine_cachemask_tensor_qparams',
     '_fused_moving_avg_obs_fq_helper',
     '_det_lu_based_helper',
+    '_lu_with_info',
 }


@@ -99,6 +100,8 @@ class TestNamedTupleAPI(TestCase):
         op(operators=['_det_lu_based_helper'],
            input=(), names=('det', 'lu', 'pivs'), hasout=False),
         op(operators=['aminmax'], input=(), names=('min', 'max'), hasout=True),
+        op(operators=['_lu_with_info'],
+           input=(), names=('LU', 'pivots', 'info'), hasout=False),
     ]

     def get_func(f):
@@ -860,8 +860,8 @@
   self: zeros_like(self)
   other: zeros_like(other)

-- name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
-  self: not_implemented("lu_with_info")
+- name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
+  self: _lu_with_info_backward(grad, self, LU, pivots)

 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
@@ -102,7 +102,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
     'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
     'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
-    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve',
+    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve', '_lu_with_info',
 }

 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
@@ -1,93 +0,0 @@
-import torch
-
-class _LU(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, self, pivot=True, get_infos=False):
-        LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
-        ctx.save_for_backward(LU, pivots)
-        ctx.mark_non_differentiable(pivots, infos)
-        return LU, pivots, infos
-
-    @staticmethod
-    def backward(ctx, LU_grad, pivots_grad, infors_grad):
-        """
-        Here we derive the gradients for the LU decomposition.
-        LIMITATIONS: square inputs of full rank.
-        If not stated otherwise, for tensors A and B,
-        `A B` means the matrix product of A and B.
-
-        Let A^H = (A^T).conj()
-
-        Forward AD:
-        Note that PyTorch returns packed LU, it is a mapping
-        A -> (B:= L + U - I, P), such that A = P L U, and
-        P is a permutation matrix, and is non-differentiable.
-
-        Using B = L + U - I, A = P L U, we get
-
-        dB = dL + dU and (*)
-        P^T dA = dL U + L dU (**)
-
-        By left/right multiplication of (**) with L^{-1}/U^{-1} we get:
-        L^{-1} P^T dA U^{-1} = L^{-1} dL + dU U^{-1}.
-
-        Note that L^{-1} dL is lower-triangular with zero diagonal,
-        and dU U^{-1} is upper-triangular.
-        Define 1_U := triu(ones(n, n)), and 1_L := ones(n, n) - 1_U, so
-
-        L^{-1} dL = 1_L * (L^{-1} P^T dA U^{-1}),
-        dU U^{-1} = 1_U * (L^{-1} P^T dA U^{-1}), where * denotes the Hadamard product.
-
-        Hence we finally get:
-        dL = L 1_L * (L^{-1} P^T dA U^{-1}),
-        dU = 1_U * (L^{-1} P^T dA U^{-1}) U
-
-        Backward AD:
-        The backward sensitivity is then:
-        Tr(B_grad^H dB) = Tr(B_grad^H dL) + Tr(B_grad^H dU) = [1] + [2].
-
-        [1] = Tr(B_grad^H dL) = Tr(B_grad^H L 1_L * (L^{-1} P^T dA U^{-1}))
-            = [using Tr(A (B * C)) = Tr((A * B^T) C)]
-            = Tr((B_grad^H L * 1_L^T) L^{-1} P^T dA U^{-1})
-            = [cyclic property of trace]
-            = Tr(U^{-1} (B_grad^H L * 1_L^T) L^{-1} P^T dA)
-            = Tr((P L^{-H} (L^H B_grad * 1_L) U^{-H})^H dA).
-        Similar, [2] can be rewritten as:
-        [2] = Tr(P L^{-H} (B_grad U^H * 1_U) U^{-H})^H dA, hence
-        Tr(A_grad^H dA) = [1] + [2]
-                        = Tr((P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H})^H dA), so
-        A_grad = P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H}.
-
-        In the code below we use the name `LU` instead of `B`, so that there is no confusion
-        in the derivation above between the matrix product and a two-letter variable name.
-        """
-        LU, pivots = ctx.saved_tensors
-        P, L, U = torch.lu_unpack(LU, pivots)
-
-        # To make sure MyPy infers types right
-        assert (L is not None) and (U is not None) and (P is not None)
-
-        # phi_L = L^H B_grad * 1_L
-        phi_L = (L.transpose(-1, -2).conj() @ LU_grad).tril_()
-        phi_L.diagonal(dim1=-2, dim2=-1).fill_(0.0)
-        # phi_U = B_grad U^H * 1_U
-        phi_U = (LU_grad @ U.transpose(-1, -2).conj()).triu_()
-        phi = phi_L + phi_U
-
-        # using the notation from above plus the variable names, note
-        # A_grad = P L^{-H} phi U^{-H}.
-        # Instead of inverting L and U, we solve two systems of equations, i.e.,
-        # the above expression could be rewritten as
-        # L^H P^T A_grad U^H = phi.
-        # Let X = P^T A_grad U_H, then
-        # X = L^{-H} phi, where L^{-H} is upper triangular, or
-        # X = torch.triangular_solve(phi, L^H)
-        # using the definition of X we see:
-        # X = P^T A_grad U_H => P X = A_grad U_H => U A_grad^H = X^H P^T, so
-        # A_grad = (U^{-1} X^H P^T)^H, or
-        # A_grad = torch.triangular_solve(X^H P^T, U)^H
-        X = torch.triangular_solve(phi, L.transpose(-1, -2).conj(), upper=True).solution
-        A_grad = torch.triangular_solve(X.transpose(-1, -2).conj() @ P.transpose(-1, -2), U, upper=True) \
-            .solution.transpose(-1, -2).conj()
-
-        return A_grad, None, None
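The docstring of the file deleted above derives the square-case formula A_grad = P L^{-H} (L^H B_grad ∘ 1_L + B_grad U^H ∘ 1_U) U^{-H}. As a quick numerical cross-check of that formula against autograd (a sketch, not part of the commit, assuming a generic full-rank square input):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
LU, pivots = torch.lu(A)
B_grad = torch.randn_like(LU)  # arbitrary cotangent for the packed factor B = L + U - I
A_grad_autograd, = torch.autograd.grad(LU, A, grad_outputs=B_grad)

# Closed form: A_grad = P L^{-H} (L^H B_grad o 1_L + B_grad U^H o 1_U) U^{-H},
# evaluated with two triangular solves instead of explicit inverses.
P, L, U = torch.lu_unpack(LU.detach(), pivots)
LH = L.transpose(-1, -2).conj()
phi = (LH @ B_grad).tril(-1) + (B_grad @ U.transpose(-1, -2).conj()).triu()
X = torch.triangular_solve(phi, LH, upper=True).solution               # X = L^{-H} phi
A_grad_manual = torch.triangular_solve(
    X.transpose(-1, -2).conj() @ P.transpose(-1, -2), U, upper=True
).solution.transpose(-1, -2).conj()                                    # (U^{-1} X^H P^T)^H

print(torch.allclose(A_grad_autograd, A_grad_manual))  # expect True
```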
@@ -453,28 +453,6 @@ class Tensor(torch._C._TensorBase):
         if has_torch_function_unary(self):
             return handle_torch_function(Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos)

-        if not torch._jit_internal.is_scripting():
-            if self.requires_grad:
-                if not (self.size(-2) == self.size(-1) and (self.dtype.is_floating_point) or self.is_complex):
-                    raise ValueError(
-                        'lu.backward works only with batches of squared full-rank matrices'
-                        ' of floating or complex types.'
-                    )
-
-                from torch._autograd_functions import _LU
-                LU, pivots, infos = _LU.apply(self, pivot, get_infos)
-                if get_infos:
-                    return LU, pivots, infos
-                else:
-                    return LU, pivots
-        else:
-            if self.requires_grad:
-                raise RuntimeError(
-                    'Script and require gradients is not supported at the moment.'
-                    'If you just want to do the forward, use .detach()'
-                    'on the input before calling the function.'
-                )
-
         LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
         if get_infos:
             return LU, pivots, infos
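Because the derivative now lives in derivatives.yaml rather than in a Python `autograd.Function`, gradients also flow under TorchScript, which the removed branch above explicitly rejected. A rough sketch of what that enables (not part of the commit):

```python
import torch

@torch.jit.script
def packed_lu(x):
    # Call the ATen op directly; its backward is now generated from derivatives.yaml.
    LU, pivots, info = torch._lu_with_info(x, pivot=True, check_errors=True)
    return LU

A = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
# Before this change, aten::_lu_with_info had no derivative formula
# (not_implemented), so this backward call would have failed.
packed_lu(A).sum().backward()
print(A.grad.shape)  # torch.Size([3, 3])
```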
@@ -3831,6 +3831,178 @@ Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim)
   return out_fw_grad;
 }

+// Let X in \C^{m \times n}, then its pivoted LU decomposition is
+// X = P L U, where P is a permutation matrix.
+//
+// Useful notation:
+// Let o denote the elementwise, or Hadamard, product.
+// k := min(m, n)
+// 1 := ones(k, k),
+// 1_U = 1.triu(),
+// 1_L = 1 - 1_U (note the diagonal is zero)
+// For a matrix A, A^H := A.transpose(-2, -1).conj()
+//
+// Below we derive the backward algorithm for the case when m <= n.
+// The case m > n could be obtained using the same idea.
+// Since we assume m <= n, the LU decomposition of X could be written as
+// X = (X1 | X2) = P L (U1 | U2) where X1, U1 in \C^{m \times m}, X2, U2 in \C^{m, n - m}
+//
+// Forward AD:
+//
+// dX = P dL U + P L dU => [left-multiply P^T]
+// (P^T dX1 | P^T dX2) = (dL U1 + L dU1 | dL U2 + L dU2) (*)
+// From (*):
+// P^T dX1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by U1^{-1}]
+// L^{-1} P^T dX1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**).
+// Note, L is lower-triangular, and so is its inverse, hence L^{-1} dL is lower-triangular.
+// Also, since the diagonal of L (all ones) is never exposed explicitly (packed representation),
+// the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0.
+// Assuming that U1 is full-rank, similarly, dU1 U1^{-1} is upper-triangular.
+// Combining these observations we conclude:
+//
+// L^{-1} dL = (L^{-1} P^T dX1 U1^{-1}) o 1_L,
+// dU1 U1^{-1} = (L^{-1} P^T dX1 U1^{-1}) o 1_U.
+//
+// Hence,
+// dL = L [(L^{-1} P^T dX1 U1^{-1}) o 1_L],
+// dU1 = [(L^{-1} P^T dX1 U1^{-1}) o 1_U] U1.
+// As for dU2, from (*) it follows
+// P^T dX2 = dL U2 + L dU2 =>
+// dU2 = L^{-1} (P^T dX2 - dL U2).
+//
+// Backward AD:
+//
+// The following equality comes in very handy:
+// Tr(A (B o C)) = Tr((A o B^T) C) (!)
+//
+// Tr(X_grad^H dX) = Tr(L_grad^H dL) + Tr(U_grad^H dU), then
+//
+// Tr(L_grad^H dL) = Tr(L_grad^H L [(L^{-1} P^T dX1 U1^{-1}) o 1_L]) = [using (!)]
+// = Tr((L_grad^H L o 1_L^T) L^{-1} P^T dX1 U1^{-1}) = [using the cyclic property of Tr]
+// = Tr(U1^{-1} (L_grad^H L o 1_L^T) L^{-1} P^T dX1)
+//
+// Similarly, using (!) and the cyclic property of the trace operator:
+// Tr(U_grad^H dU) = Tr(U1_grad^H dU1) + Tr(U2_grad^H dU2)
+// = Tr(U1^{-1} (U1 U1_grad^H o 1_U^T) L^{-1} P^T dX1)
+//   + Tr(U2_grad^H L^{-1} P^T dX2)
+//   - Tr(U1^{-1} (U2 U2_grad^H o 1_L^T) L^{-1} P^T dX1)
+//
+// By combining the matrices to the left from dX1 and dX2 and then applying conjugate transposition,
+// we finally arrive at:
+//
+// X1_grad = P L^{-H} [L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H},
+// X2_grad = P L^{-H} U2_grad
+Tensor plu_backward_base(
+  const variable_list& grads,
+  const Tensor& self,
+  const Tensor& P,
+  const Tensor& L,
+  const Tensor& U) {
+  auto L_grad = grads[0];
+  auto U_grad = grads[1];
+
+  auto m = self.size(-2);
+  auto n = self.size(-1);
+  auto k = std::min(m, n);
+
+  auto L_principal = L.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto L_principal_H = L_principal.transpose(-2, -1).conj();
+  auto L_grad_principal = L_grad.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto U_principal = U.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto U_principal_H = U_principal.transpose(-2, -1).conj();
+  auto U_grad_principal = U_grad.narrow(-2, 0, k).narrow(-1, 0, k);
+
+  auto phi_L = L_principal_H.matmul(L_grad_principal).tril_(-1);
+  auto phi_U = U_grad_principal.matmul(U_principal_H).triu_();
+
+  auto phi = phi_L + phi_U;
+  auto psi = at::zeros_like(self);
+
+  Tensor self_grad;
+  if (m <= n) {
+    auto U_complement = U.narrow(-2, 0, k).narrow(-1, k, n - k);
+    auto U_grad_complement = U_grad.narrow(-2, 0, k).narrow(-1, k, n - k);
+
+    auto phi_complement = U_grad_complement.matmul(U_complement.transpose(-2, -1).conj()).tril_(-1);
+    phi.sub_(phi_complement);
+
+    // recall the result for X1_grad and X2_grad from above.
+    // It can be rewritten as
+    // (X1_grad | X2_grad) = P L^{-H} psi, where
+    // psi = (psi1 | psi2)
+    //     = ([L^H L_grad o 1_L + U1_grad U1^H o 1_U - U2_grad U2^H o 1_L] U1^{-H} | U2_grad),
+    // so it is filled in parts.
+    //
+    // fill psi2 in
+    psi.narrow(-2, 0, k).narrow(-1, k, n - k).copy_(U_grad_complement);
+
+    // solve for psi1 to avoid the inversion of U1^H
+    auto psi_principal = std::get<0>(at::triangular_solve(
+      phi.transpose(-2, -1).conj(),
+      U_principal,
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/false
+    )).transpose(-2, -1).conj();
+    psi.narrow(-2, 0, k).narrow(-1, 0, k).copy_(psi_principal);
+
+    // solve for the grad to avoid the inversion of L1^H
+    self_grad = P.matmul(
+      std::get<0>(at::triangular_solve(
+        psi,
+        L_principal_H,
+        /*upper=*/true,
+        /*transpose=*/false,
+        /*unitriangular=*/true
+      ))
+    );
+  }
+  else {
+    // variables psi and phi carry the same meaning as in the case (m <= n),
+    // albeit they are differently defined.
+    auto L_complement = L.narrow(-2, k, m - k).narrow(-1, 0, k);
+    auto L_grad_complement = L_grad.narrow(-2, k, m - k).narrow(-1, 0, k);
+
+    auto phi_complement = L_complement.transpose(-2, -1).conj().matmul(L_grad_complement).triu_();
+    phi.sub_(phi_complement);
+
+    psi.narrow(-2, k, m - k).narrow(-1, 0, k).copy_(L_grad_complement);
+
+    auto psi_principal = std::get<0>(at::triangular_solve(
+      phi,
+      L_principal_H,
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/true
+    ));
+    psi.narrow(-2, 0, k).narrow(-1, 0, k).copy_(psi_principal);
+
+    self_grad = std::get<0>(at::triangular_solve(
+      P.matmul(psi).transpose(-2, -1),
+      U_principal.conj(),
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/false
+    )).transpose(-2, -1);
+  }
+
+  return self_grad;
+}
+
+Tensor _lu_with_info_backward(
+  const Tensor& grad,
+  const Tensor& self,
+  const Tensor& LU,
+  const Tensor& pivs) {
+  Tensor P, L, U;
+  std::tie(P, L, U) = at::lu_unpack(LU, pivs);
+  // Note that packed LU could be represented as
+  // LU = L + U - I, hence
+  // L_grad = LU_grad,
+  // U_grad = LU_grad.
+  return plu_backward_base({/*L_grad=*/grad, /*U_grad=*/grad}, self, P, L, U);
+}
+
 } // namespace details
 } // namespace generated
 } // namespace autograd
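The comment block above derives, for m <= n, X1_grad = P L^{-H} [L^H L_grad ∘ 1_L + U1_grad U1^H ∘ 1_U − U2_grad U2^H ∘ 1_L] U1^{-H} and X2_grad = P L^{-H} U2_grad. As a sketch (not part of the commit), the simpler X2 block can be cross-checked against autograd through `torch.lu` plus `torch.lu_unpack` on a wide matrix:

```python
import torch

torch.manual_seed(0)
m, n = 3, 5
X = torch.randn(m, n, dtype=torch.double, requires_grad=True)
P, L, U = torch.lu_unpack(*torch.lu(X))

# Cotangent that is non-zero only on the U2 block (columns m:), so the
# corresponding block of the input gradient should equal P L^{-H} U2_grad.
U_grad = torch.zeros_like(U)
U_grad[:, m:] = torch.randn(m, n - m, dtype=torch.double)
X_grad, = torch.autograd.grad(U, X, grad_outputs=U_grad)

LH = L.transpose(-1, -2).conj()
X2_grad = P @ torch.triangular_solve(U_grad[:, m:], LH, upper=True).solution
print(torch.allclose(X_grad[:, m:], X2_grad))  # expect True
```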
@@ -258,6 +258,7 @@ Tensor lu_unpack_backward(
   const Tensor& LU_data,
   bool unpack_data
 );

 Tensor _det_lu_based_helper_backward(
   const Tensor& det_grad,
   const Tensor& det,
@@ -266,6 +267,20 @@ Tensor _det_lu_based_helper_backward(
   const Tensor& pivs
 );

+Tensor plu_backward_base(
+  const variable_list& grads,
+  const Tensor& self,
+  const Tensor& P,
+  const Tensor& L,
+  const Tensor& U
+);
+
+Tensor _lu_with_info_backward(
+  const Tensor& grad,
+  const Tensor& self,
+  const Tensor& LU,
+  const Tensor& pivs
+);
+
 Tensor cat_jvp(at::TensorList tensors, int64_t dim);
 Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim);
 Tensor gather_with_keepdimed_indices(const Tensor& input, int64_t dim, const Tensor& indices, bool keepdim);
@@ -10,7 +10,6 @@ from .overrides import (
     handle_torch_function)
 from ._jit_internal import boolean_dispatch, List
 from ._jit_internal import _overload as overload
-from torch._autograd_functions import _LU

 Tensor = torch.Tensor
 from torch import _VF
@@ -1459,8 +1458,10 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
     * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.

     .. warning::
-        The LU factorization does have backward support,
-        but only for square inputs of full rank.
+        The gradients of this function will only be finite when :attr:`A` is full rank.
+        This is because the LU decomposition is only differentiable at full-rank matrices.
+        Furthermore, if :attr:`A` is close to not being full rank,
+        the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.

     Args:
         A (Tensor): the tensor to factor of size :math:`(*, m, n)`
@@ -1508,23 +1509,6 @@ def _lu_impl(A, pivot=True, get_infos=False, out=None):
         ... print('LU factorization succeeded for all samples!')
         LU factorization succeeded for all samples!
     """
-    if not torch._jit_internal.is_scripting():
-        if A.requires_grad:
-            if not (A.size(-2) == A.size(-1) and (A.dtype.is_floating_point or A.is_complex)):
-                raise ValueError(
-                    'lu.backward works only with batches of squared full-rank matrices'
-                    ' of floating or complex types.'
-                )
-
-            return _LU.apply(A, pivot, get_infos)
-    else:
-        if A.requires_grad:
-            raise RuntimeError(
-                'Script and require gradients is not supported at the moment.'
-                'If you just want to do the forward, use .detach()'
-                'on the input before calling the function.'
-            )
-
     # If get_infos is True, then we don't need to check for errors and vice versa
     return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))

@@ -3204,8 +3204,8 @@ def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs):
     # not needed once OpInfo tests support Iterables
     def generate_samples():
         batch_shapes = ((), (3,), (3, 3))
-        for batch_shape, get_infos in product(batch_shapes, (True, False)):
-            shape = batch_shape + (S, S)
+        for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)):
+            shape = batch_shape + (S + size_delta, S)
             input = make_tensor(shape, device, dtype, requires_grad=requires_grad, low=None, high=None)
             yield SampleInput(input, args=(True, get_infos))

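The updated sampler above sweeps `size_delta` over (-2, -1, 0, +1, +2), so tall, square, and wide matrices are all exercised. A small illustration of the shapes it yields (a sketch, assuming the test suite's small-dimension constant `S` is 5, as it was in this file at the time):

```python
from itertools import product

S = 5  # assumed value of the test suite's small-dimension constant
batch_shapes = ((), (3,), (3, 3))
for batch_shape, get_infos, size_delta in product(batch_shapes, (True, False), (-2, -1, 0, +1, +2)):
    # e.g. (3, 5), (4, 5), (5, 5), (6, 5), (7, 5), with optional batch dimensions
    print(batch_shape + (S + size_delta, S), get_infos)
```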
@@ -6533,16 +6533,16 @@ op_db: List[OpInfo] = [
            op=torch.lu,
            dtypes=floating_and_complex_types(),
            supports_inplace_autograd=False,
+           # we use in-place operations which cannot be avoided.
+           # This causes vmap failures, hence we skip batched gradient checks
+           check_batched_grad=False,
            check_batched_gradgrad=False,
            supports_out=False,
            sample_inputs_func=sample_inputs_lu,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
            skips=(
-               # we skip jit tests because lu_backward is impelemented as autograd.Function,
-               # which does not support autograd with scripting
+               # we skip jit tests because `lu` is a torch function
                SkipInfo('TestJit', 'test_variant_consistency_jit'),
-               # Skip operator schema test because this is a functional and not an operator
-               SkipInfo('TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
            )),
     OpInfo('lu_solve',
            op=torch.lu_solve,
@@ -6555,7 +6555,7 @@ op_db: List[OpInfo] = [
            dtypes=floating_and_complex_types(),
            supports_inplace_autograd=False,
            # we use in-place operations which cannot be avoided.
-           # This cases vmap failures, hence we skip batched gradient checks
+           # This causes vmap failures, hence we skip batched gradient checks
            check_batched_grad=False,
            supports_out=True,
            sample_inputs_func=sample_inputs_lu_unpack,