Implement gradient for the residuals of torch.linalg.lstsq (#148526)
Fixes #147543. I have written some tests in Python using `gradcheck`. Please advise where I should put these tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148526
Approved by: https://github.com/lezcano
This commit is contained in: parent ea86b8d315, commit 59f14d19ae
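A minimal sketch of the kind of `gradcheck` test the message describes (not part of this commit; the shapes and the `driver='gels'` choice are illustrative assumptions — residuals are only produced for tall, full-rank systems with a driver that computes them):

    # Hedged sketch: numerically check the new residuals gradient.
    import torch
    from torch.autograd import gradcheck

    def residuals(A, B):
        # [1] selects the residuals output of linalg.lstsq.
        return torch.linalg.lstsq(A, B, driver='gels')[1]

    A = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)  # m > n
    B = torch.randn(5, 2, dtype=torch.float64, requires_grad=True)
    assert gradcheck(residuals, (A, B))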
@@ -1040,9 +1040,10 @@
   result: logsumexp_jvp(self_p, self_t, dim, keepdim)
 
 - name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
-  self, b: linalg_lstsq_backward(grad, self, b, grad_input_mask)
-  solution: linalg_lstsq_jvp(self_p, b_p, self_t, b_t)
-  output_differentiability: [True, False, False, False]
+  self, b: linalg_lstsq_backward(grads[0], grads[1], self, b, solution, grad_input_mask)
+  solution: linalg_lstsq_solution_jvp(self_p, b_p, self_t, b_t)
+  residuals: linalg_lstsq_residuals_jvp(self_p, b_p, self_t, b_t, solution, residuals)
+  output_differentiability: [True, True, False, False]
 
 - name: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
   self: zeros_like(self)
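Flipping `output_differentiability` from `[True, False, False, False]` to `[True, True, False, False]` is what lets autograd flow through the second output. Roughly (a sketch assuming this commit, with `driver='gels'` so residuals are actually computed):

    import torch

    A = torch.randn(6, 2, dtype=torch.float64, requires_grad=True)
    B = torch.randn(6, 3, dtype=torch.float64)
    res = torch.linalg.lstsq(A, B, driver='gels').residuals
    res.sum().backward()   # before this commit, residuals did not require grad
    print(A.grad.shape)    # torch.Size([6, 2])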
@@ -3827,27 +3827,64 @@ std::tuple<Tensor, Tensor> linalg_eig_jvp(
   return std::make_pair(std::move(dL), std::move(dV));
 }
 
-Tensor linalg_lstsq_jvp(
+Tensor linalg_lstsq_solution_jvp(
     const Tensor& A,
-    const Tensor& B,
+    const Tensor& B_,
     const Tensor& dA,
-    const Tensor& dB) {
+    const Tensor& dB_) {
   at::NoTF32Guard disable_tf32;
+  const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B_);
+  const auto vector_to_matrix = [vector_case](const Tensor& X) {
+    return vector_case ? X.unsqueeze(-1) : X;
+  };
+  const auto matrix_to_vector = [vector_case](const Tensor& X) {
+    return vector_case ? X.squeeze(-1) : X;
+  };
+  auto B = vector_to_matrix(B_);
+  auto dB = vector_to_matrix(dB_);
   auto pinvA = at::linalg_pinv(A);
   auto dpinvA = pinv_jvp(A, pinvA, dA);
-  auto dX = dpinvA.matmul(B) + pinvA.matmul(dB);
+  auto dX = matrix_to_vector(dpinvA.matmul(B) + pinvA.matmul(dB));
   return dX;
 }
 
+Tensor linalg_lstsq_residuals_jvp(
+    const Tensor& A,
+    const Tensor& B_,
+    const Tensor& dA,
+    const Tensor& dB_,
+    const Tensor& X_,
+    const Tensor& L) {
+  at::NoTF32Guard disable_tf32;
+  if (L.numel() == 0) {
+    return L.clone();
+  }
+  const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B_);
+  const auto vector_to_matrix = [vector_case](const Tensor& X) {
+    return vector_case ? X.unsqueeze(-1) : X;
+  };
+  auto B = vector_to_matrix(B_);
+  auto dB = vector_to_matrix(dB_);
+  auto X = vector_to_matrix(X_);
+  auto r = A.matmul(X) - B;
+  auto dr = dA.matmul(X) - dB;
+  // Danskin's theorem lets us compute dL as if X did not depend on A and B
+  auto dL = 2 * at::real(r * dr.conj()).sum(-2);
+  return dL;
+}
+
 std::tuple<Tensor, Tensor> linalg_lstsq_backward(
     const Tensor& gX_,
+    const Tensor& gL_,
     const Tensor& A,
     const Tensor& B_,
+    const Tensor& X_,
     const std::array<bool, 2>& grad_input_mask) {
   at::NoTF32Guard disable_tf32;
   auto A_requires_grad = grad_input_mask[0];
   auto B_requires_grad = grad_input_mask[1];
-  if (!gX_.defined() || (!A_requires_grad && !B_requires_grad)) {
+  if ((!gX_.defined() && !gL_.numel()) || // gL_ undefined or have shape [0]
+      (!A_requires_grad && !B_requires_grad)) {
     return {};
   }
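The `residuals_jvp` formula can be sanity-checked against finite differences (an illustrative sketch, not part of the patch; `driver='gels'`, the seed, and the sizes are my assumptions):

    import torch

    torch.manual_seed(0)
    A = torch.randn(7, 3, dtype=torch.float64)
    B = torch.randn(7, 2, dtype=torch.float64)
    dA, dB = torch.randn_like(A), torch.randn_like(B)

    X = torch.linalg.lstsq(A, B, driver='gels').solution
    r, dr = A @ X - B, dA @ X - dB        # Danskin: X treated as constant
    dL = 2 * (r * dr.conj()).real.sum(-2)

    res = lambda A, B: torch.linalg.lstsq(A, B, driver='gels').residuals
    eps = 1e-6
    fd = (res(A + eps * dA, B + eps * dB) - res(A - eps * dA, B - eps * dB)) / (2 * eps)
    print(torch.allclose(dL, fd, atol=1e-5))  # should print True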
@@ -3859,20 +3896,39 @@ std::tuple<Tensor, Tensor> linalg_lstsq_backward(
     return vector_case ? X.squeeze(-1) : X;
   };
 
-  auto gX = vector_to_matrix(gX_);
   auto B = vector_to_matrix(B_);
-  Tensor pinvA = at::linalg_pinv(A);
-  Tensor A_grad, B_grad;
-  if (A_requires_grad) {
-    auto pinvA_grad = gX.matmul(B.mH());
-    A_grad = pinv_backward(pinvA_grad, pinvA, A);
+  Tensor A_grad_X, B_grad_X, A_grad, B_grad;
+
+  if (gX_.defined()) { // Gradient from solution
+    auto gX = vector_to_matrix(gX_);
+    Tensor pinvA = at::linalg_pinv(A);
+    if (A_requires_grad) {
+      auto pinvA_grad = gX.matmul(B.mH());
+      A_grad_X = pinv_backward(pinvA_grad, pinvA, A);
+    }
+    if (B_requires_grad) {
+      // Equivalent to
+      // B_grad = std::get<0>(at::linalg_lstsq(A.mH(), gX, rcond, driver));
+      // but we avoid this approach as `gelsy` is non-deterministic
+      B_grad_X = matrix_to_vector(pinvA.mH().matmul(gX));
+    }
   }
 
-  if (B_requires_grad) {
-    // Equivalent to
-    // B_grad = std::get<0>(at::linalg_lstsq(A.mH(), gX, rcond, driver));
-    // but we avoid this approach as `gelsy` is non-deterministic
-    B_grad = matrix_to_vector(pinvA.mH().matmul(gX));
+  if (gL_.numel()) { // Gradient from residuals
+    auto X = vector_to_matrix(X_);
+    auto r = A.matmul(X) - B;
+    auto gL = gL_.unsqueeze(-2);
+    if (A_requires_grad) {
+      auto A_grad_L = 2 * (gL * r).matmul(X.mH());
+      A_grad = A_grad_X.defined() ? A_grad_X + A_grad_L : A_grad_L;
+    }
+    if (B_requires_grad) {
+      auto B_grad_L = matrix_to_vector(-2 * gL * r);
+      B_grad = B_grad_X.defined() ? B_grad_X + B_grad_L : B_grad_L;
+    }
+  } else { // gX_.defined() == true
+    A_grad = A_grad_X;
+    B_grad = B_grad_X;
   }
 
   return std::make_tuple(A_grad, B_grad);
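For reference, the residual terms added to the backward follow from a short calculation (a sketch in my notation, valid in the full-rank `m > n` regime where residuals are defined). With r = AX - B and the column-wise squared residuals L_j = \sum_i |r_{ij}|^2, the solution X = A^+ B minimizes L, so Danskin's theorem lets us differentiate as if X were constant:

    \delta L = 2 \,\mathrm{Re}\,\langle r,\; \delta A\,X - \delta B \rangle

Reading off the adjoints, with the incoming gradient \bar L broadcast along the row dimension (the `gL_.unsqueeze(-2)` above):

    \bar A = 2\,(\bar L \odot r)\,X^{\mathsf H}, \qquad \bar B = -2\,(\bar L \odot r)

which are exactly the `2 * (gL * r).matmul(X.mH())` and `-2 * gL * r` terms in the code.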
@@ -631,11 +631,18 @@ Tensor linalg_eig_backward(
     const Tensor& V,
     const bool is_hermitian,
     const bool symeig_eigenvectors = true);
-Tensor linalg_lstsq_jvp(
+Tensor linalg_lstsq_solution_jvp(
     const Tensor& A,
-    const Tensor& B,
+    const Tensor& B_,
     const Tensor& dA,
-    const Tensor& dB);
+    const Tensor& dB_);
+Tensor linalg_lstsq_residuals_jvp(
+    const Tensor& A,
+    const Tensor& B_,
+    const Tensor& dA,
+    const Tensor& dB_,
+    const Tensor& X_,
+    const Tensor& L);
 std::tuple<Tensor, Tensor> triangular_solve_backward(
     const Tensor& grad_x,
     const Tensor& grad_m,
@@ -887,9 +894,11 @@ Tensor linalg_det_jvp(
     const Tensor& pivots,
     const bool use_A_T);
 std::tuple<Tensor, Tensor> linalg_lstsq_backward(
-    const Tensor& grad,
+    const Tensor& gX_,
+    const Tensor& gL,
     const Tensor& A,
     const Tensor& B_,
+    const Tensor& X_,
     const std::array<bool, 2>& grad_input_mask);
 Tensor linalg_lu_backward(
     const Tensor& L_grad,
@@ -1126,7 +1126,7 @@ of shape `(*, m, n)`, `(*, m, k)` respectively, it contains
 
 - `solution`: the least squares solution. It has shape `(*, n, k)`.
 - `residuals`: the squared residuals of the solutions, that is, :math:`\|AX - B\|_F^2`.
-  It has shape equal to the batch dimensions of :attr:`A`.
+  It has shape `(*, k)`.
   It is computed when `m > n` and every matrix in :attr:`A` is full-rank,
   otherwise, it is an empty tensor.
   If :attr:`A` is a batch of matrices and any matrix in the batch is not full rank,
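The corrected shape is easy to confirm (a sketch; the batch size, `m`, `n`, `k`, and `driver='gels'` are illustrative assumptions):

    import torch

    A = torch.randn(4, 7, 3, dtype=torch.float64)   # batch of 4, m=7 > n=3
    B = torch.randn(4, 7, 5, dtype=torch.float64)   # k=5 right-hand sides
    print(torch.linalg.lstsq(A, B, driver='gels').residuals.shape)
    # torch.Size([4, 5]), i.e. `(*, k)` -- not just the batch shape (4,)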
@@ -1517,8 +1517,9 @@ op_db: list[OpInfo] = [
         "linalg.lstsq",
         aten_name="linalg_lstsq",
         variant_test_name="grad_oriented",
-        # gradchecks for forward AD fails with multi-Tensor outputs
-        op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[0],
+        # gradchecks for forward AD fails with full output tuple
+        # works when taking [:2], which is (solution, residuals)
+        op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[:2],
         supports_out=False,
         dtypes=floating_and_complex_types(),
         sample_inputs_func=sample_inputs_linalg_lstsq,
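In isolation, the forward-AD check that the `[:2]` slice enables looks roughly like this (a hedged sketch; `driver='gels'` and the shapes are my assumptions, and `rank`/`singular_values` are excluded because they are non-differentiable):

    import torch
    from torch.autograd import gradcheck

    op = lambda a, b: torch.linalg.lstsq(a, b, driver='gels')[:2]  # (solution, residuals)
    A = torch.randn(6, 3, dtype=torch.float64, requires_grad=True)
    B = torch.randn(6, 2, dtype=torch.float64, requires_grad=True)
    assert gradcheck(op, (A, B), check_forward_ad=True)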