Commit Graph

2 Commits

Author SHA1 Message Date
Nikita Vedeneev
dfc7fa03e5 lu_backward: more numerically stable and with complex support. (#53994)
Summary:
As per title.

Numerical stability increased by replacing inverses with solutions to systems of linear triangular equations.

Unblocks computing `torch.det` for FULL-rank inputs of complex dtypes via the LU decomposition once https://github.com/pytorch/pytorch/pull/48125/files is merged:
```
LU, pivots = input.lu()
P, L, U = torch.lu_unpack(LU, pivots)
det_input = P.det() * torch.prod(U.diagonal(0, -1, -2), dim=-1)  # P is not differentiable, so we are fine even if it is complex.
```

Unfortunately, since `lu_backward` is implemented as `autograd.Function`, we cannot support both autograd and scripting at the moment.
The solution would be to move all the lu-related methods to ATen, see https://github.com/pytorch/pytorch/issues/53364.

Resolves https://github.com/pytorch/pytorch/issues/52891
TODOs:
* extend lu_backward for tall/wide matrices of full rank.
* move lu-related functionality to ATen and make it differentiable.
* handle rank-deficient inputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53994

Reviewed By: pbelevich

Differential Revision: D27188529

Pulled By: anjali411

fbshipit-source-id: 8e053b240413dbf074904dce01cd564583d1f064
2021-03-25 13:33:58 -07:00
Nikita Vedeneev
c31ced4246 make torch.lu differentiable. (#46284)
Summary:
As per title. Limitations: only for batches of squared full-rank matrices.

CC albanD

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46284

Reviewed By: zou3519

Differential Revision: D24448266

Pulled By: albanD

fbshipit-source-id: d98215166268553a648af6bdec5a32ad601b7814
2020-10-23 10:13:46 -07:00