[PowerSGD] Add orthogonalization with QR factorization (#72043)

Summary:
### 🚀 The feature, motivation and pitch
Following the discussion in https://github.com/pytorch/pytorch/issues/65813, I added the QR factorization to powerSGD_hook.py
Gram-Schmidt orthogonalization can't be fully replaced, because _torch.linalg.qr_ doesn't support half-precision. Moreover, in my tests Gram-Schmidt is still faster when the rank is less than 3, so the hook now chooses between the two methods based on rank and dtype.
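For reference, which path is taken depends on the `matrix_approximation_rank` and the gradient dtype configured when registering the hook. A minimal sketch of registering the hook on a DDP model (the model and process-group setup are placeholders, not part of this change):

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

# Assumes torch.distributed has already been initialized with a CUDA-capable backend.
model = DDP(nn.Linear(1024, 1024).cuda())

# rank > 2 with fp32 gradients takes the new torch.linalg.qr path;
# rank <= 2, or fp16/bf16 gradients, keep the Gram-Schmidt path.
state = powerSGD.PowerSGDState(
    process_group=None,           # None falls back to the default process group
    matrix_approximation_rank=4,
    start_powerSGD_iter=10,
)
model.register_comm_hook(state, powerSGD.powerSGD_hook)
```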

This is one sample experiment timing powerSGD_hook on ResNext101 with the two different methods:
![Screenshot from 2022-01-31 18-14-00](https://user-images.githubusercontent.com/42100908/151840929-270c67dd-9fe7-4f11-8e70-8bf2d0ba678d.png)
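For anyone who wants to reproduce a comparison like this in isolation, here is a rough sketch that times only the orthogonalization step on a random tall matrix (the shape, rank, and iteration count are made up, not the benchmark above; `_orthogonalize_gram_schmidt` is the private helper added in this PR):

```python
import time
import torch
from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import (
    _orthogonalize_gram_schmidt,
)

def bench(fn, matrix, iters=100):
    # Time `fn` on a fresh copy of `matrix` each iteration (both methods work in place).
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(matrix.clone())
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

rank = 4
m = torch.randn(4096, rank, device="cuda")

qr_ms = bench(
    lambda x: torch.linalg.qr(
        x, out=(x, torch.empty(rank, rank, device=x.device, dtype=x.dtype))
    ),
    m,
) * 1e3
gs_ms = bench(_orthogonalize_gram_schmidt, m) * 1e3
print(f"QR: {qr_ms:.3f} ms  Gram-Schmidt: {gs_ms:.3f} ms")
```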

### Alternatives
Use _torch.orgqr(*torch.geqrf(matrix))_. In my tests its performance is similar to _torch.linalg.qr_.
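A sketch of what that alternative could look like, with a hypothetical `_orthogonalize_geqrf` helper that mirrors the in-place behavior of `_orthogonalize` (not part of this PR):

```python
import torch

def _orthogonalize_geqrf(matrix):
    # Householder QR: geqrf returns the packed factorization and the Householder
    # scalars tau; orgqr then materializes the orthonormal Q explicitly.
    q = torch.orgqr(*torch.geqrf(matrix))
    matrix.copy_(q)  # write Q back in place, like the other orthogonalization helpers

m = torch.randn(4096, 4)
_orthogonalize_geqrf(m)
# The columns of the result are orthonormal.
print(torch.allclose(m.t() @ m, torch.eye(4), atol=1e-5))
```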

### Additional context
_No response_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72043

Reviewed By: albanD

Differential Revision: D34042781

Pulled By: cbalioglu

fbshipit-source-id: e331179d3b7ac40d445b651fc473b16ae4ead462

@@ -9,13 +9,30 @@ from . import default_hooks as default
def _orthogonalize(matrix, epsilon=0):
    """
    Decide between Gram-Schmidt or QR factorization to orthogonalize the matrix.
    QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.
    """
    assert len(matrix.shape) == 2 and matrix.shape[1] <= matrix.shape[0]
    rank = matrix.shape[1]
    dtype = matrix.dtype
    if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
        _orthogonalize_gram_schmidt(matrix, epsilon=epsilon)
    else:
        # Q is written back into `matrix` in place; the R factor is discarded.
        torch.linalg.qr(
            matrix,
            out=(
                matrix,
                torch.empty(rank, rank, device=matrix.device, dtype=dtype)
            )
        )

def _orthogonalize_gram_schmidt(matrix, epsilon=0):
    """
    Applies Gram-Schmidt procedure to orthogonalize a given 2D tensor.
    If epsilon is 0, this is equivalent to `torch.qr(matrix, out=(matrix, _))`.
    """
    # TODO Consider using Q = torch.orgqr(*torch.geqrf(A)) to compute the Q of the QR _much_ faster
    # and more reliably.
    # Works on FP32/64 or complex numbers (does not work for half precision)
    num_cols = matrix.shape[1]
    for i in range(num_cols):
        # Normalize the i'th column.