Implement gradient for the residuals of torch.linalg.lstsq (#148526)

Fixes #147543.

I have written some tests in Python using `gradcheck`. Please advise where I should put these tests.
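For reference, a minimal sketch of what such a `gradcheck` test could look like (the free-standing form, the shapes, and the `gelsd` driver are illustrative assumptions, not the actual tests from this PR):

```python
import torch
from torch.autograd import gradcheck

def residuals_fn(A, B):
    # Tall, almost surely full-rank A with a driver that populates residuals,
    # so the residuals output is non-empty and differentiable.
    return torch.linalg.lstsq(A, B, driver="gelsd").residuals

A = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
B = torch.randn(5, 2, dtype=torch.double, requires_grad=True)
assert gradcheck(residuals_fn, (A, B))
```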
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148526
Approved by: https://github.com/lezcano
Xinyuan Zhao 2025-03-10 12:35:07 +00:00 committed by PyTorch MergeBot
parent ea86b8d315
commit 59f14d19ae
5 changed files with 93 additions and 26 deletions


@@ -1040,9 +1040,10 @@
   result: logsumexp_jvp(self_p, self_t, dim, keepdim)
 
 - name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
-  self, b: linalg_lstsq_backward(grad, self, b, grad_input_mask)
-  solution: linalg_lstsq_jvp(self_p, b_p, self_t, b_t)
-  output_differentiability: [True, False, False, False]
+  self, b: linalg_lstsq_backward(grads[0], grads[1], self, b, solution, grad_input_mask)
+  solution: linalg_lstsq_solution_jvp(self_p, b_p, self_t, b_t)
+  residuals: linalg_lstsq_residuals_jvp(self_p, b_p, self_t, b_t, solution, residuals)
+  output_differentiability: [True, True, False, False]
 
 - name: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
   self: zeros_like(self)
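The two new entries above are what route both reverse- and forward-mode AD through `residuals`; a small usage sketch (the tall inputs and the `gelsd` driver are assumptions chosen so that `residuals` is populated):

```python
import torch

A = torch.randn(6, 3, dtype=torch.double, requires_grad=True)
B = torch.randn(6, 2, dtype=torch.double, requires_grad=True)

# Reverse mode: goes through linalg_lstsq_backward with a residuals cotangent (grads[1])
res = torch.linalg.lstsq(A, B, driver="gelsd").residuals
gA, gB = torch.autograd.grad(res.sum(), (A, B))

# Forward mode: goes through linalg_lstsq_residuals_jvp
def solution_and_residuals(A, B):
    out = torch.linalg.lstsq(A, B, driver="gelsd")
    return out.solution, out.residuals

_, (dsol, dres) = torch.func.jvp(
    solution_and_residuals,
    (A.detach(), B.detach()),
    (torch.randn_like(A), torch.randn_like(B)),
)
```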


@@ -3827,27 +3827,64 @@ std::tuple<Tensor, Tensor> linalg_eig_jvp(
   return std::make_pair(std::move(dL), std::move(dV));
 }
 
-Tensor linalg_lstsq_jvp(
+Tensor linalg_lstsq_solution_jvp(
     const Tensor& A,
-    const Tensor& B,
+    const Tensor& B_,
     const Tensor& dA,
-    const Tensor& dB) {
+    const Tensor& dB_) {
   at::NoTF32Guard disable_tf32;
+  const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B_);
+  const auto vector_to_matrix = [vector_case](const Tensor& X) {
+    return vector_case ? X.unsqueeze(-1) : X;
+  };
+  const auto matrix_to_vector = [vector_case](const Tensor& X) {
+    return vector_case ? X.squeeze(-1) : X;
+  };
+  auto B = vector_to_matrix(B_);
+  auto dB = vector_to_matrix(dB_);
   auto pinvA = at::linalg_pinv(A);
   auto dpinvA = pinv_jvp(A, pinvA, dA);
-  auto dX = dpinvA.matmul(B) + pinvA.matmul(dB);
+  auto dX = matrix_to_vector(dpinvA.matmul(B) + pinvA.matmul(dB));
   return dX;
 }
 
+Tensor linalg_lstsq_residuals_jvp(
+    const Tensor& A,
+    const Tensor& B_,
+    const Tensor& dA,
+    const Tensor& dB_,
+    const Tensor& X_,
+    const Tensor& L) {
+  at::NoTF32Guard disable_tf32;
+  if (L.numel() == 0) {
+    return L.clone();
+  }
+  const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B_);
+  const auto vector_to_matrix = [vector_case](const Tensor& X) {
+    return vector_case ? X.unsqueeze(-1) : X;
+  };
+  auto B = vector_to_matrix(B_);
+  auto dB = vector_to_matrix(dB_);
+  auto X = vector_to_matrix(X_);
+  auto r = A.matmul(X) - B;
+  auto dr = dA.matmul(X) - dB;
+  // Danskin's theorem lets us compute dL as if X did not depend on A and B
+  auto dL = 2 * at::real(r * dr.conj()).sum(-2);
+  return dL;
+}
+
 std::tuple<Tensor, Tensor> linalg_lstsq_backward(
     const Tensor& gX_,
+    const Tensor& gL_,
     const Tensor& A,
     const Tensor& B_,
+    const Tensor& X_,
     const std::array<bool, 2>& grad_input_mask) {
   at::NoTF32Guard disable_tf32;
   auto A_requires_grad = grad_input_mask[0];
   auto B_requires_grad = grad_input_mask[1];
-  if (!gX_.defined() || (!A_requires_grad && !B_requires_grad)) {
+  if ((!gX_.defined() && !gL_.numel()) || // gL_ undefined or have shape [0]
+      (!A_requires_grad && !B_requires_grad)) {
     return {};
   }
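The formula in `linalg_lstsq_residuals_jvp` above comes from a short calculation; a sketch in the same notation, assuming `m > n` and full-rank `A` so that the residuals exist:

```latex
% r = AX - B, and L_j = \sum_i |r_{ij}|^2 is the residual of column j of B.
\[
  dr = dA\,X + A\,dX - dB, \qquad
  dL_j = 2\,\operatorname{Re}\sum_i \overline{r_{ij}}\, dr_{ij}.
\]
% At the least-squares optimum the normal equations give A^H r = 0, so the
% A\,dX term contributes 2\,\operatorname{Re}\big((A^H r_j)^H dX_j\big) = 0
% (the Danskin / envelope argument from the comment). What is left is
\[
  dL_j = 2\,\operatorname{Re}\sum_i \overline{r_{ij}}\,(dA\,X - dB)_{ij},
\]
% i.e. dr = dA.matmul(X) - dB and dL = 2 * at::real(r * dr.conj()).sum(-2).
```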
@@ -3859,20 +3896,39 @@ std::tuple<Tensor, Tensor> linalg_lstsq_backward(
     return vector_case ? X.squeeze(-1) : X;
   };
 
-  auto gX = vector_to_matrix(gX_);
   auto B = vector_to_matrix(B_);
 
-  Tensor pinvA = at::linalg_pinv(A);
-  Tensor A_grad, B_grad;
-  if (A_requires_grad) {
-    auto pinvA_grad = gX.matmul(B.mH());
-    A_grad = pinv_backward(pinvA_grad, pinvA, A);
+  Tensor A_grad_X, B_grad_X, A_grad, B_grad;
+  if (gX_.defined()) { // Gradient from solution
+    auto gX = vector_to_matrix(gX_);
+    Tensor pinvA = at::linalg_pinv(A);
+    if (A_requires_grad) {
+      auto pinvA_grad = gX.matmul(B.mH());
+      A_grad_X = pinv_backward(pinvA_grad, pinvA, A);
+    }
+    if (B_requires_grad) {
+      // Equivalent to
+      // B_grad = std::get<0>(at::linalg_lstsq(A.mH(), gX, rcond, driver));
+      // but we avoid this approach as `gelsy` is non-deterministic
+      B_grad_X = matrix_to_vector(pinvA.mH().matmul(gX));
+    }
   }
-  if (B_requires_grad) {
-    // Equivalent to
-    // B_grad = std::get<0>(at::linalg_lstsq(A.mH(), gX, rcond, driver));
-    // but we avoid this approach as `gelsy` is non-deterministic
-    B_grad = matrix_to_vector(pinvA.mH().matmul(gX));
+
+  if (gL_.numel()) { // Gradient from residuals
+    auto X = vector_to_matrix(X_);
+    auto r = A.matmul(X) - B;
+    auto gL = gL_.unsqueeze(-2);
+    if (A_requires_grad) {
+      auto A_grad_L = 2 * (gL * r).matmul(X.mH());
+      A_grad = A_grad_X.defined() ? A_grad_X + A_grad_L : A_grad_L;
+    }
+    if (B_requires_grad) {
+      auto B_grad_L = matrix_to_vector(-2 * gL * r);
+      B_grad = B_grad_X.defined() ? B_grad_X + B_grad_L : B_grad_L;
+    }
+  } else { // gX_.defined() == true
+    A_grad = A_grad_X;
+    B_grad = B_grad_X;
   }
 
   return std::make_tuple(A_grad, B_grad);
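The residual terms in the backward pass (`A_grad_L`, `B_grad_L`) follow from pairing a cotangent `gL` with the same expression for `dL` (again a sketch under the same assumptions):

```latex
% \langle g_L, dL \rangle
%   = 2\,\operatorname{Re}\sum_{i,j} g_{L,j}\,\overline{r_{ij}}\,(dA\,X - dB)_{ij},
% so reading off the gradients that reproduce this real pairing:
\[
  \bar{A}_L = 2\,(g_L \odot r)\,X^{H}, \qquad
  \bar{B}_L = -\,2\,(g_L \odot r),
\]
% where g_L broadcasts over the row index i (the unsqueeze(-2) in the code).
% These are A_grad_L and B_grad_L, accumulated onto the contributions from gX.
```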


@@ -631,11 +631,18 @@ Tensor linalg_eig_backward(
     const Tensor& V,
     const bool is_hermitian,
     const bool symeig_eigenvectors = true);
-Tensor linalg_lstsq_jvp(
+Tensor linalg_lstsq_solution_jvp(
     const Tensor& A,
-    const Tensor& B,
+    const Tensor& B_,
     const Tensor& dA,
-    const Tensor& dB);
+    const Tensor& dB_);
+Tensor linalg_lstsq_residuals_jvp(
+    const Tensor& A,
+    const Tensor& B_,
+    const Tensor& dA,
+    const Tensor& dB_,
+    const Tensor& X_,
+    const Tensor& L);
 std::tuple<Tensor, Tensor> triangular_solve_backward(
     const Tensor& grad_x,
     const Tensor& grad_m,
@@ -887,9 +894,11 @@ Tensor linalg_det_jvp(
     const Tensor& pivots,
     const bool use_A_T);
 std::tuple<Tensor, Tensor> linalg_lstsq_backward(
-    const Tensor& grad,
+    const Tensor& gX_,
+    const Tensor& gL,
     const Tensor& A,
     const Tensor& B_,
+    const Tensor& X_,
     const std::array<bool, 2>& grad_input_mask);
 Tensor linalg_lu_backward(
     const Tensor& L_grad,

@@ -1126,7 +1126,7 @@ of shape `(*, m, n)`, `(*, m, k)` respectively, it contains
 - `solution`: the least squares solution. It has shape `(*, n, k)`.
 - `residuals`: the squared residuals of the solutions, that is, :math:`\|AX - B\|_F^2`.
-  It has shape equal to the batch dimensions of :attr:`A`.
+  It has shape `(*, k)`.
   It is computed when `m > n` and every matrix in :attr:`A` is full-rank,
   otherwise, it is an empty tensor.
   If :attr:`A` is a batch of matrices and any matrix in the batch is not full rank,
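For concreteness, the documented shape can be checked directly (the batch size and driver below are illustrative):

```python
import torch

A = torch.randn(4, 5, 3, dtype=torch.double)  # batch of tall, almost surely full-rank matrices
B = torch.randn(4, 5, 2, dtype=torch.double)
sol, res, rank, sv = torch.linalg.lstsq(A, B, driver="gelsd")

print(res.shape)  # torch.Size([4, 2]): the batch dimensions of A followed by k
# res holds the per-column squared residual norms of A @ X - B
print(torch.allclose(res, (A @ sol - B).pow(2).sum(dim=-2)))
```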


@@ -1517,8 +1517,9 @@ op_db: list[OpInfo] = [
         "linalg.lstsq",
         aten_name="linalg_lstsq",
         variant_test_name="grad_oriented",
-        # gradchecks for forward AD fails with multi-Tensor outputs
-        op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[0],
+        # gradchecks for forward AD fails with full output tuple
+        # works when taking [:2], which is (solution, residuals)
+        op=lambda a, b, driver: torch.linalg.lstsq(a, b, driver=driver)[:2],
         supports_out=False,
         dtypes=floating_and_complex_types(),
         sample_inputs_func=sample_inputs_linalg_lstsq,
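Stripped of the OpInfo machinery, the forward-AD gradcheck that this `[:2]` slice enables is roughly (an illustrative sketch, not the actual test harness):

```python
import torch
from torch.autograd import gradcheck

# mirrors the OpInfo lambda: differentiate only (solution, residuals)
op = lambda a, b: torch.linalg.lstsq(a, b, driver="gelsd")[:2]

a = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(5, 2, dtype=torch.double, requires_grad=True)
assert gradcheck(op, (a, b), check_forward_ad=True)
```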