Revert D25511543: [Gradient Compression] Implement the original layerwise PowerSGD

Test Plan: revert-hammer

Differential Revision: D25511543 (71f3399e19)

Original commit changeset: 19ef188bc2d4

fbshipit-source-id: a363641a059aeacc57684884998cf8fb7363d748

This commit is contained in:
parent 5cde23fdd4
commit ad9923e5d5
@@ -66,17 +66,6 @@ class DDPCommHookType(Enum):
        comm_hook=powerSGD.powerSGD_hook,
        matrix_approximation_rank=2,
    )
    # Batching can lead to faster training at the cost of accuracy.
    BATCHED_POWER_SGD = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.batched_powerSGD_hook,
        matrix_approximation_rank=1,
    )
    BATCHED_POWER_SGD_RANK2 = partial(
        _powerSGD_comm_hook_wrapper,
        comm_hook=powerSGD.batched_powerSGD_hook,
        matrix_approximation_rank=2,
    )


def register_ddp_comm_hook(
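Each enum entry above pre-binds a communication hook and its matrix_approximation_rank via functools.partial, so a caller selects a compression configuration by name alone. A minimal runnable sketch of that pattern, with a hypothetical stand-in for _powerSGD_comm_hook_wrapper (the real wrapper's signature may differ):

    from functools import partial

    # Hypothetical wrapper, for illustration only: a real one would register
    # comm_hook on the DDP model with the given rank stored in the state.
    def _demo_wrapper(comm_hook, matrix_approximation_rank, model=None, state=None):
        print(comm_hook.__name__, matrix_approximation_rank)

    def _demo_hook(state, bucket):
        pass

    # partial freezes the hook and the rank; what remains is a registration
    # function that only needs the model and state, like the entries above.
    batched_rank2 = partial(_demo_wrapper, comm_hook=_demo_hook, matrix_approximation_rank=2)
    batched_rank2()  # prints: _demo_hook 2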
@@ -73,195 +73,12 @@ class PowerSGDState(object):
def powerSGD_hook(
    state: PowerSGDState,
    bucket,
) -> torch.futures.Future:
    """
    This DDP communication hook implements the original PowerSGD gradient compression
    algorithm described in https://arxiv.org/abs/1905.13727.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:
    1) Views the input flattened 1D gradient tensor as two groups of per-parameter tensors:
       high-rank tensors and vector-like rank-1 tensors (for biases).
    2) Handles rank-1 tensors by allreducing them without compression:
        2.1) Allocates contiguous memory for those rank-1 tensors,
             and allreduces all the rank-1 tensors as a batch, without compression;
        2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor.
    3) Handles high-rank tensors by PowerSGD compression:
        3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M,
             such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
        3.2) Computes each P in Ps, which is equal to MQ;
        3.3) Allreduces Ps as a batch;
        3.4) Orthogonalizes each P in Ps;
        3.5) Computes each Q in Qs, which is approximately equal to M^TP;
        3.6) Allreduces Qs as a batch;
        3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T.

    TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
    one left multiplication and one right multiplication.
    For warm start, it can take one such step at a time and alternate between them.

    Arguments:
        state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc.
        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode at this time,
            only exactly one tensor is stored in this bucket.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
    """
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]
    device = input_tensor.device
    dtype = input_tensor.dtype
    # Unflatten the input tensor into per-parameter tensors, for layer-wise compression.
    tensors = [
        input_tensor[offset : offset + length].view(sizes)
        for offset, length, sizes in zip(
            bucket.get_offsets(), bucket.get_lengths(), bucket.get_sizes_list()
        )
    ]
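    # Note: these per-parameter views alias the storage of input_tensor, so
    # writing the decompressed results into them updates the flat gradient
    # bucket in place.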

    # Step I: Handle rank-1 tensors.
    # Allocate contiguous memory for the rank-1 tensors, so they can be allreduced efficiently without compression.
    rank1_tensors = [tensor for tensor in tensors if tensor.ndimension() <= 1]
    rank1_tensors_memory = (
        torch.cat([tensor.view(-1) for tensor in rank1_tensors])
        if rank1_tensors
        else torch.tensor([], device=device)
    )
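    # Note: torch.cat copies into fresh contiguous memory, so the allreduced
    # values must be copied back into the per-parameter views afterwards (see
    # unpack_rank1_tensors_and_allreduce_ps below).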

    # Step II: Handle high-rank tensors.
    # Allocate contiguous memory for Ps and Qs, to allreduce the compressed high-rank tensors efficiently.
    high_rank_tensors = [
        tensor.view(tensor.shape[0], -1)
        for tensor in tensors
        if tensor.ndimension() > 1
    ]
    total_Ps_size = 0
    ps_memory = None  # TODO(wayi): Store it in a dict of PowerSGDState for warm-up.
    total_Qs_size = 0
    qs_memory = None  # TODO(wayi): Store it in a dict of PowerSGDState for warm-up.
    for tensor in high_rank_tensors:
        n, m = tensor.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        total_Ps_size += n * matrix_approximation_rank
        total_Qs_size += m * matrix_approximation_rank
    ps_memory = torch.empty(total_Ps_size, device=device, dtype=dtype)
    qs_memory = torch.empty(total_Qs_size, device=device, dtype=dtype)

    # Create Ps and Qs that point to the allocated memory.
    ps = []
    qs = []
    p_idx = 0
    q_idx = 0
    for tensor in high_rank_tensors:
        n, m = tensor.shape
        matrix_approximation_rank = min(n, m, state.matrix_approximation_rank)
        ps.append(
            ps_memory[p_idx : p_idx + n * matrix_approximation_rank].view(
                n, matrix_approximation_rank
            )
        )
        qs.append(
            qs_memory[q_idx : q_idx + m * matrix_approximation_rank].view(
                m, matrix_approximation_rank
            )
        )
        p_idx += n * matrix_approximation_rank
        q_idx += m * matrix_approximation_rank
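    # All Ps (and likewise all Qs) are views into a single contiguous buffer,
    # so the low-rank factors of every layer can be allreduced with one
    # collective call each.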

    # Initialize and then orthogonalize Qs.
    with torch.random.fork_rng(devices=[]):
        # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training.
        # The seed makes sure that the initial random values are the same across all the DDP replicas.
        # Such a seed should differ at every step.
        # Since it is very slow to fork RNG state across all the CUDA devices,
        # only fork on CPU and then move the generated tensor to the CUDA device.
        torch.manual_seed(state.rng.randint(1_000_000_000))
        for q in qs:
            q.data = torch.randn(
                *q.shape,
                device="cpu",
                dtype=dtype,
            ).to(device)
            _orthogonalize(q)

    # Compute Ps.
    for tensor, q, p in zip(high_rank_tensors, qs, ps):
        torch.matmul(tensor, q, out=p)
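    # Each P = M @ Q holds the coordinates of M's rows in the rank-r subspace
    # spanned by the columns of its Q.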

    # This allreduce is only applied to rank-1 tensors,
    # so it should have been kicked off before the above computation on the high-rank tensors to hide more communication costs.
    # However, this somehow requires a separate future chain at this time.
    allreduce_contiguous_rank1_tensors_fut = dist.all_reduce(
        rank1_tensors_memory, group=group_to_use, async_op=True
    ).get_future()

    def unpack_rank1_tensors_and_allreduce_ps(fut):
        rank1_tensors_memory = fut.value()[0].div_(world_size)
        idx = 0
        for tensor in rank1_tensors:
            tensor.copy_(rank1_tensors_memory[idx : idx + tensor.shape[0]])
            idx += tensor.shape[0]

        # Since these Ps will be orthogonalized later, there is no need to divide them by the world size.
        return [
            dist.all_reduce(ps_memory, group=group_to_use, async_op=True)
            .get_future()
            .wait()[0]
        ]

    def compute_qs(fut):
        ps_memory = fut.wait()[0]
        for p in ps:
            _orthogonalize(p)

        # Compute Qs.
        for tensor, p, q in zip(high_rank_tensors, ps, qs):
            torch.matmul(tensor.t(), p, out=q)

        # Allreduce Qs.
        return [
            dist.all_reduce(qs_memory, group=group_to_use, async_op=True)
            .get_future()
            .wait()[0]
        ]

    def decompress(fut):
        qs_memory = fut.wait()[0].div_(world_size)

        for p, q, tensor in zip(ps, qs, high_rank_tensors):
            torch.matmul(p, q.t(), out=tensor)
            assert not torch.any(torch.isnan(tensor))
        return [input_tensor]

    return (
        allreduce_contiguous_rank1_tensors_fut.then(
            unpack_rank1_tensors_and_allreduce_ps
        )
        .then(compute_qs)
        .then(decompress)
    )
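Stripped of the distributed plumbing, the hook above performs one round of subspace (power) iteration per bucket, i.e. steps 3.1 through 3.7 of the docstring. A local, single-process sketch of that math in plain PyTorch, using torch.linalg.qr as a stand-in for this file's _orthogonalize helper (an assumption for illustration, not this file's exact implementation):

    import torch

    def powersgd_round(M: torch.Tensor, rank: int) -> torch.Tensor:
        # Returns the low-rank approximation P @ Q^T of M after one round.
        n, m = M.shape
        r = min(n, m, rank)  # mirrors the clamping of matrix_approximation_rank
        # 3.1) Q starts as an orthogonalized standard-normal matrix.
        Q, _ = torch.linalg.qr(torch.randn(m, r))
        # 3.2) P = MQ; in the hook, all Ps are then allreduced as one batch (3.3).
        P = M @ Q
        # 3.4) Orthogonalize P so the next projection is well conditioned.
        P, _ = torch.linalg.qr(P)
        # 3.5) Q = M^T P; in the hook, all Qs are then allreduced as one batch (3.6).
        Q = M.t() @ P
        # 3.7) Decompress: P @ Q^T approximates M.
        return P @ Q.t()

    M = torch.randn(64, 32)
    for r in (1, 2, 8):
        M_hat = powersgd_round(M, r)
        print(r, (torch.norm(M - M_hat) / torch.norm(M)).item())

The relative error shrinks as the rank grows; a higher matrix_approximation_rank trades communication volume for fidelity.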


def batched_powerSGD_hook(
    state: PowerSGDState,
    bucket,
) -> torch.futures.Future:
    """
    This DDP communication hook implements a simplified PowerSGD gradient compression
    algorithm described in https://arxiv.org/abs/1905.13727.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression to the flattened input tensor that batches per-parameter tensors as follows:
    compression as follows:
    1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
    2) Creates two low-rank tensors P and Q for decomposing M,
       such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized;
@@ -288,7 +105,7 @@ def batched_powerSGD_hook(

    Example::
        >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1)
        >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
    """
    process_group = state.process_group
    group_to_use = process_group if process_group is not None else dist.group.WORLD
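In contrast to the layerwise hook removed above, batched_powerSGD_hook compresses the entire flat bucket as a single matrix. A small runnable sketch of step 1's square view with zero padding (shapes only; an illustration, not this file's exact implementation):

    import math
    import torch

    flat = torch.randn(1000)  # stand-in for the flattened 1D gradient bucket
    side = int(math.ceil(math.sqrt(flat.numel())))  # smallest square that fits
    padded = torch.zeros(side * side, dtype=flat.dtype)
    padded[: flat.numel()] = flat
    M = padded.view(side, side)  # square matrix to decompose as P @ Q^T
    assert torch.equal(M.reshape(-1)[: flat.numel()], flat)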

@@ -16,7 +16,6 @@ from typing import Union, NamedTuple
import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars
import torch.nn as nn

@@ -2820,52 +2819,6 @@ class DistributedTest:
                        msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
                    )

        @unittest.skipIf(
            BACKEND != "nccl",
            "Only NCCL backend supports DistributedDataParallel",
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        @skip_if_rocm
        def test_DistributedDataParallel_powerSGD_ddp_comm_hook(self):
            stream = torch.cuda.Stream(self.rank)
            rank = self.rank
            with torch.cuda.stream(stream):
                net = torch.nn.parallel.DistributedDataParallel(
                    torch.nn.Linear(1, 5).to(rank), device_ids=[rank]
                )
                process_group = torch.distributed.new_group([0, 1])
                state = powerSGD.PowerSGDState(
                    process_group=process_group, matrix_approximation_rank=1
                )
                net.register_comm_hook(state=state, hook=powerSGD.powerSGD_hook)
                # NOTE: batched_powerSGD_hook cannot pass the following test, because it has a lower accuracy.
                for i in range(1000):
                    # Clear gradients manually.
                    grad = net.module.weight.grad
                    if grad is not None:
                        grad.requires_grad_(False)
                        grad.zero_()
                    # Forward + backward.
                    batch = torch.tensor([rank]).float().cuda(rank)
                    loss = net(batch).sum()
                    loss.backward()
                    # For each worker, the gradient on the weight should be worker_rank.
                    grad = net.module.weight.grad
                    avg = grad.clone()
                    # All-reducing the gradient averages should give us the gradient
                    # average. If not, then one of the workers has not correctly
                    # written back the averaged gradient before this all-reduce call.
                    dist.all_reduce(avg)
                    world_size = int(os.environ["WORLD_SIZE"])
                    avg.div_(world_size)
                    expected_grad = sum(i for i in range(world_size)) / world_size
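                    # e.g., with WORLD_SIZE=2 the per-worker gradients 0 and 1
                    # average to (0 + 1) / 2 = 0.5.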
                    self.assertEqual(
                        avg[0, 0],
                        expected_grad,
                        msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
                    )


        @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
                         "Only NCCL & Gloo backends support DistributedDataParallel")
        @skip_if_no_gpu