diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
index f6622ae9cea..b4faf2a312e 100644
--- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
+++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py
@@ -9,13 +9,30 @@ from . import default_hooks as default
 
 
 def _orthogonalize(matrix, epsilon=0):
+    """
+    Decide between Gram-Schmidt or QR factorization to orthogonalize the matrix.
+    QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.
+    """
+    assert len(matrix.shape) == 2 and matrix.shape[1] <= matrix.shape[0]
+
+    rank = matrix.shape[1]
+    dtype = matrix.dtype
+    if rank <= 2 or dtype in [torch.float16, torch.bfloat16]:
+        _orthogonalize_gram_schmidt(matrix, epsilon=epsilon)
+    else:
+        torch.linalg.qr(
+            matrix,
+            out=(
+                matrix,
+                torch.empty(rank, rank, device=matrix.device, dtype=dtype)
+            )
+        )
+
+def _orthogonalize_gram_schmidt(matrix, epsilon=0):
     """
     Applies Gram-Schmidt procedure to orthogonalize a given 2D tensor.
     If epsilon is 0, this is equivalent to `torch.qr(matrix, out=(matrix, _))`,
     """
-    # TODO Consider using Q = torch.orgqr(*torch.geqrf(A)) to compute the Q of the QR _much_ faster
-    # and more reliably.
-    # Works on FP32/64 or complex numbers (does not work for half precision)
     num_cols = matrix.shape[1]
     for i in range(num_cols):
         # Normalize the i'th column.
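
For context, here is a minimal standalone sketch of the dispatch rule this change introduces. The `orthogonalize_demo` helper, the shapes, and the tolerance below are illustrative assumptions, not part of the PR: Gram-Schmidt is kept for rank <= 2 or half-precision inputs, and `torch.linalg.qr` handles everything else (epsilon handling is omitted for brevity).

# Illustrative sketch only -- orthogonalize_demo is not the PR's code; it mirrors
# the decision rule of the new _orthogonalize().
import torch

def orthogonalize_demo(matrix: torch.Tensor) -> None:
    """Orthogonalize the columns of a tall 2D tensor in place."""
    assert matrix.ndim == 2 and matrix.shape[1] <= matrix.shape[0]
    rank = matrix.shape[1]
    if rank <= 2 or matrix.dtype in (torch.float16, torch.bfloat16):
        # Gram-Schmidt path: low rank or half-precision, where QR is unsupported.
        for i in range(rank):
            col = matrix[:, i : i + 1]
            col /= torch.linalg.norm(col)
            if i + 1 < rank:
                rest = matrix[:, i + 1 :]
                rest -= col * torch.matmul(col.t(), rest)
    else:
        # QR path: keep only the orthonormal factor Q. The PR writes Q back into
        # the same buffer via the out= tuple; copy_ is used here for clarity.
        q, _ = torch.linalg.qr(matrix)
        matrix.copy_(q)

m = torch.randn(64, 4)
orthogonalize_demo(m)
# Columns are orthonormal: Q^T Q should be (approximately) the identity.
print(torch.allclose(m.t() @ m, torch.eye(4), atol=1e-5))

Note the design choice in the actual diff: passing the input tensor itself as the Q slot of the out= tuple lets torch.linalg.qr overwrite the matrix in place, matching the in-place behavior of the existing Gram-Schmidt routine, while the R factor is written to a scratch tensor and discarded.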