[PowerSGD] Add orthogonalization with QR factorization (#72043)
Summary:

### 🚀 The feature, motivation and pitch

Following the discussion in https://github.com/pytorch/pytorch/issues/65813, I added QR factorization as an orthogonalization option to powerSGD_hook.py.

Gram-Schmidt orthogonalization can't be fully replaced, because _torch.linalg.qr_ doesn't work with half precision. Moreover, in my tests Gram-Schmidt is faster when the rank is less than 3.

This is one sample experiment timing powerSGD_hook on ResNext101 with the two different methods:

[timing comparison figure]

### Alternatives

Use _torch.orgqr(*torch.geqrf(matrix))_. In my tests its performance is similar to _torch.linalg.qr_.

### Additional context

_No response_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72043
Reviewed By: albanD
Differential Revision: D34042781
Pulled By: cbalioglu
fbshipit-source-id: e331179d3b7ac40d445b651fc473b16ae4ead462
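As a minimal sketch of the two full-precision orthogonalization routes discussed above (_torch.linalg.qr_ and the _torch.orgqr(*torch.geqrf(matrix))_ alternative) — not taken from the PR; the matrix shape is an arbitrary tall-and-skinny example like the low-rank factors PowerSGD produces:

```python
import torch

# Illustrative comparison of the two QR-based orthogonalization routes.
matrix = torch.randn(4096, 4, dtype=torch.float32)

# Route 1: reduced QR via torch.linalg.qr; Q has orthonormal columns.
q_linalg = torch.linalg.qr(matrix).Q

# Route 2: the alternative from the summary, Householder QR in two steps.
q_orgqr = torch.orgqr(*torch.geqrf(matrix))

# Both yield orthonormal columns; check Q^T Q ≈ I.
eye = torch.eye(matrix.shape[1])
print(torch.allclose(q_linalg.T @ q_linalg, eye, atol=1e-5))
print(torch.allclose(q_orgqr.T @ q_orgqr, eye, atol=1e-5))
```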
This commit is contained in:
parent
b1ba221acc
commit
f64bf3839a
@@ -9,13 +9,30 @@ from . import default_hooks as default
```python
def _orthogonalize(matrix, epsilon=0):
    """
    Decide between Gram-Schmidt or QR factorization to orthogonalize the matrix.
    QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.
    """
    assert len(matrix.shape) == 2 and matrix.shape[1] <= matrix.shape[0]

    rank = matrix.shape[1]
    dtype = matrix.dtype
    if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
        _orthogonalize_gram_schmidt(matrix, epsilon=epsilon)
    else:
        torch.linalg.qr(
            matrix,
            out=(
                matrix,
                torch.empty(rank, rank, device=matrix.device, dtype=dtype)
            )
        )
```
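As an aside (not part of the diff): a small usage sketch of the dispatch above, assuming the private helper is importable from the hook module as shown; shapes are illustrative.

```python
import torch
# Assumption: _orthogonalize is importable from the PowerSGD hook module.
from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import _orthogonalize

device = "cuda" if torch.cuda.is_available() else "cpu"

# Full precision with rank > 2: takes the new in-place torch.linalg.qr path.
p_rank4 = torch.randn(1024, 4, dtype=torch.float32, device=device)
_orthogonalize(p_rank4)

# Rank <= 2 (and likewise float16/bfloat16 inputs): falls back to the Gram-Schmidt loop.
p_rank1 = torch.randn(1024, 1, dtype=torch.float32, device=device)
_orthogonalize(p_rank1)
```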
```python
def _orthogonalize_gram_schmidt(matrix, epsilon=0):
    """
    Applies Gram-Schmidt procedure to orthogonalize a given 2D tensor.
    If epsilon is 0, this is equivalent to `torch.qr(matrix, out=(matrix, _))`,
    """
    # TODO Consider using Q = torch.orgqr(*torch.geqrf(A)) to compute the Q of the QR _much_ faster
    # and more reliably.
    # Works on FP32/64 or complex numbers (does not work for half precision)
    num_cols = matrix.shape[1]
    for i in range(num_cols):
        # Normalize the i'th column.
```
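The hunk shown above ends at the column-normalization step. For reference, here is a self-contained sketch of a classical in-place Gram-Schmidt loop of this shape, consistent with the docstring above but written as an illustration rather than a copy of the file's code; the function name and example sizes are my own.

```python
import torch

def gram_schmidt_sketch(matrix: torch.Tensor, epsilon: float = 0.0) -> None:
    """Illustrative in-place Gram-Schmidt over the columns of a tall 2D tensor."""
    num_cols = matrix.shape[1]
    for i in range(num_cols):
        # Normalize the i'th column (epsilon guards against a zero norm).
        col = matrix[:, i : i + 1]
        col /= torch.norm(col) + epsilon
        # Remove the i'th column's component from all remaining columns.
        if i + 1 < num_cols:
            rest = matrix[:, i + 1 :]
            rest -= torch.sum(col * rest, dim=0) * col

# Example: after the loop, the columns are orthonormal (Q^T Q ≈ I).
m = torch.randn(512, 4)
gram_schmidt_sketch(m)
print(torch.allclose(m.T @ m, torch.eye(4), atol=1e-5))
```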