**Summary:** Proof that the new replicate API is composable with TP.

**Test Case**

1. pytest test/distributed/_composable/test_replicate_training.py -k test_replicate_tp

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162853
Approved by: https://github.com/mori360
ghstack dependencies: #162830, #162836, #162839, #162851
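For orientation, here is a minimal sketch of the composition pattern that `test_replicate_tp` below exercises. The toy `nn.Sequential` model, the mesh sizes, the `"cuda"` device string, and the plan keys are illustrative assumptions, not part of this PR:

```python
# Condensed sketch of composing replicate with tensor parallelism, modeled on
# TestReplicateTPTraining._test_replicate_tp below. Assumes the default process
# group is already initialized (e.g. launched via torchrun with 4 ranks).
import torch.nn as nn
from torch.distributed._composable.replicate_with_fsdp import replicate
from torch.distributed.tensor import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)

# 3-D mesh: (dp_replicate, dp_shard, tp); dp_shard=1 means pure replication.
mesh = init_device_mesh(
    "cuda", (2, 1, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp")
)
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# 1) Shard the linears across the "tp" mesh dimension first.
parallelize_module(
    model,
    mesh["tp"],
    {"0": ColwiseParallel(use_local_output=False), "1": RowwiseParallel()},
)
# 2) Then apply replicate over the data-parallel dimensions; parameters end up
#    as 3-D DTensors placed over (dp_replicate, dp_shard, tp).
replicate(model, device_mesh=mesh["dp_replicate", "dp_shard"])
```

The ordering mirrors the test: TP is applied before replicate, so replicate wraps modules whose parameters are already TP-sharded DTensors.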
# Owner(s): ["oncall: distributed"]

import contextlib
import copy
import functools
import itertools
import unittest
from collections import defaultdict
from collections.abc import Iterable
from typing import Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.distributed._composable.replicate_with_fsdp import replicate
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
    apply_activation_checkpointing,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import (
    CPUOffloadPolicy,
    FSDPModule,
    OffloadPolicy,
    register_fsdp_forward_method,
)
from torch.distributed.tensor import DTensor, init_device_mesh
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    check_sharded_parity,
    compiled_fsdp_test,
    FSDPTest,
    FSDPTestMultiThread,
    MLP,
    MLPStack,
    patch_all_gather,
    patch_reduce_scatter,
)
from torch.testing._internal.common_utils import (
    get_cycles_per_ms,
    run_tests,
    TEST_HPU,
    wrapSwapTensorsTest,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
    TransformerBlock,
)


c10d_ops = torch.ops.c10d
funcol = torch.ops.c10d_functional

from torch.testing._internal.common_fsdp import get_devtype


device_type = torch.device(get_devtype())


class TestReplicateForwardInputs(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(1)
    def test_root_move_forward_input_to_device(self):
        device = torch.device(device_type.type, 0)

        class ParamlessModule(nn.Module):
            def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]):
                # Check that Replicate moved the inputs to GPU, including recursing
                # into the tuple data structure
                assert x.device == device, f"Expects {device} but got {x.device}"
                assert ys[0].device == device, (
                    f"Expects {device} but got {ys[0].device}"
                )
                assert ys[1].device == device, (
                    f"Expects {device} but got {ys[1].device}"
                )
                y = ys[0] + ys[1]
                return x + y + 1

        model = ParamlessModule().to(device)
        replicate(model).to(device)
        x = torch.randn((3,))
        ys = (torch.randn((3,)), torch.randn((3,)))
        self.assertEqual(x.device, torch.device("cpu"))
        self.assertEqual(ys[0].device, torch.device("cpu"))
        self.assertEqual(ys[1].device, torch.device("cpu"))
        model(x, ys)


class TestReplicateRegisteredParams(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 4

    @skip_if_lt_x_gpu(1)
    def test_param_registration_after_forward(self):
        """Tests the parameter registration after forward."""
        device = torch.device(device_type.type, 0)
        # Single Replicate group
        for reshard_after_forward in (True, False, None):
            torch.manual_seed(42)
            model = MLP(3, device)
            # Since seed is per process, not per thread, we broadcast to ensure
            # the same parameters across ranks
            for param in model.parameters():
                dist.broadcast(param, src=0)
            ref_model = copy.deepcopy(model)
            replicate(model, reshard_after_forward=reshard_after_forward)  # root only
            inp = torch.randn((2, 3), device=device_type.type)
            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())
            model(inp)
            if reshard_after_forward:
                self._assert_dtensor_params(model.parameters())
            else:
                self._assert_tensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())
            model.reshard()  # however, we can manually reshard
            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())

        # Multiple Replicate groups
        for reshard_after_forward in (True, False, None):
            torch.manual_seed(42)
            model = nn.Sequential(MLP(3, device), MLP(3, device))
            for param in model.parameters():
                dist.broadcast(param, src=0)
            ref_model = copy.deepcopy(model)
            replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
            replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
            replicate(model, reshard_after_forward=reshard_after_forward)

            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())
            model(inp)
            non_root_params = list(model[0].in_proj.parameters()) + list(
                model[0].out_proj.parameters()
            )
            root_params = list(set(model.parameters()) - set(non_root_params))
            if reshard_after_forward is None:
                self._assert_dtensor_params(non_root_params)
                self._assert_tensor_params(root_params)
            elif reshard_after_forward:
                self._assert_dtensor_params(non_root_params)
                self._assert_dtensor_params(root_params)
            else:
                self._assert_tensor_params(non_root_params)
                self._assert_tensor_params(root_params)
            self._assert_same_params(model.parameters(), ref_model.parameters())
            for module in model.modules():
                if isinstance(module, FSDPModule):
                    module.reshard()  # however, we can manually reshard
            self._assert_dtensor_params(model.parameters())
            self._assert_same_params(model.parameters(), ref_model.parameters())

    @skip_if_lt_x_gpu(1)
    def test_param_registration_after_backward(self):
        """Tests the parameter registration after backward."""
        device = torch.device(device_type.type, 0)
        # Single Replicate group
        for reshard_after_forward in (True, False):
            model = MLP(8, device)
            replicate(model, reshard_after_forward=reshard_after_forward)  # root only
            inp = torch.randn((2, 8), device=device_type.type)
            self._assert_dtensor_params(model.parameters())
            model(inp).sum().backward()
            self._assert_dtensor_params(model.parameters())

        # Multiple Replicate groups
        for reshard_after_forward in (True, False):
            model = MLP(8, device)
            replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
            replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
            replicate(model, reshard_after_forward=reshard_after_forward)
            self._assert_dtensor_params(model.parameters())
            model(inp).sum().backward()
            self._assert_dtensor_params(model.parameters())

    def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
        # need to iterate over the list multiple times
        params = list(params)
        self.assertGreater(len(params), 0)
        for param in params:
            self.assertNotIsInstance(param, DTensor)
            self.assertIsInstance(param, torch.Tensor)

    def _assert_dtensor_params(self, params: Iterable[nn.Parameter]):
        params = list(params)
        self.assertGreater(len(params), 0)
        for param in params:
            self.assertIsInstance(param, DTensor)

    def _assert_same_params(
        self, params: Iterable[nn.Parameter], ref_params: Iterable[nn.Parameter]
    ):
        params, ref_params = list(params), list(ref_params)
        self.assertEqual(len(params), len(ref_params))
        for param, ref_param in zip(params, ref_params):
            if isinstance(param, DTensor):
                param = param.full_tensor()
            self.assertEqual(param.shape, ref_param.shape)
            self.assertEqual(param, ref_param)


class TestReplicateCastAfterInit(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(1)
    @wrapSwapTensorsTest(True)
    def test_to_float64_after_init(self):
        """Tests that the user can cast the module to float64 after init."""
        # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
        # better numerics. The important part is changing the dtype.

        torch.manual_seed(42)
        mlp_dim, device, dtype = 4, device_type, torch.float64
        model = MLP(mlp_dim, device=device)
        for param in model.parameters():
            dist.broadcast(param, src=0)
        ref_model = copy.deepcopy(model).to(dtype)

        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for module in (model.in_proj, model.out_proj, model):
            replicate(module)
        model.to(dtype)
        for param in model.parameters():
            self.assertEqual(param.dtype, dtype)
            self.assertEqual(param.to_local().dtype, dtype)
            self.assertEqual(param._spec.tensor_meta.dtype, dtype)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        check_sharded_parity(self, ref_model, model)
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, mlp_dim), device=device_type.type, dtype=dtype)
        for iter_idx in range(10):
            losses: list[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad)
                    param.grad.div_(self.world_size)

            self.assertEqual(losses[0], losses[1])
            check_sharded_parity(self, ref_model, model)
            for param in model.parameters():
                self.assertEqual(param.dtype, dtype)
                self.assertEqual(param.to_local().dtype, dtype)
                self.assertEqual(param._spec.tensor_meta.dtype, dtype)
                self.assertEqual(param.grad.dtype, dtype)
                self.assertEqual(param.grad.to_local().dtype, dtype)
                self.assertEqual(param.grad._spec.tensor_meta.dtype, dtype)
            for _optim in (ref_optim, optim):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))


class TestReplicate1DTrainingCore(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(8, torch.get_device_module(device_type).device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_single_group(self):
        """
        Tests train parity with DDP for a single FSDP group when sharding
        parameters on dim-0.
        """
        self.run_subtests(
            {
                "lin_shapes": [
                    [(16, 15), (15, 8)],
                    [(7, 15), (15, 3)],
                    [(16, 17), (17, 8)],
                ],
                "use_shard_placement_fn": [False],
            },
            self._test_train_parity_single_group,
        )

    def _test_train_parity_single_group(
        self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
    ):
        torch.manual_seed(42)
        model = nn.Sequential(
            nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
        )
        ref_model = copy.deepcopy(model).to(device_type)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)

        replicate(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        torch.manual_seed(42 + self.rank + 1)
        inp = (torch.randn((4, lin_shapes[0][0]), device=device_type.type),)
        for iter_idx in range(10):
            losses: list[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(*inp).sum())
                losses[-1].backward()

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad)
                    param.grad.div_(self.world_size)

            for _optim in (ref_optim, optim):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                _optim.step()
            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(TEST_HPU, "Sleep kernel not supported for HPU")
    @compiled_fsdp_test(compile_compute_on_module=Transformer)
    def test_train_parity_multi_groups(self):
        """
        Tests train parity against DDP when using multiple parameter groups for
        communication (for communication and computation overlap plus memory
        reduction).
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False],
                "test_device_type": [device_type.type],
                "offload_policy": [OffloadPolicy()],
                "delay_after_forward": [False, True],
                "delay_before_all_gather": [False, True],
                "delay_before_reduce_scatter": [False, True],
                "delay_before_optim": [False, True],
                "unshard_async_op": [False],
            },
            self._test_train_parity_multi_group,
        )

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(TEST_HPU, "sleep kernel not supported on HPU")
    def test_train_parity_multi_group_cpu_offload_eager(self):
        """
        Tests train parity when using multiple parameter groups for
        communication and CPU offloading.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True],  # save CI time
                "offload_policy": [
                    CPUOffloadPolicy(pin_memory=True),
                    CPUOffloadPolicy(pin_memory=False),
                ],
                "test_device_type": [device_type.type],
                "delay_after_forward": [False, True],
                "delay_before_all_gather": [False, True],
                "delay_before_reduce_scatter": [False, True],
                "delay_before_optim": [False, True],
                "unshard_async_op": [False],
            },
            self._test_train_parity_multi_group,
        )

    def _test_train_parity_multi_group(
        self,
        reshard_after_forward: Union[bool, int],
        offload_policy: OffloadPolicy,
        test_device_type: str,
        delay_after_forward: bool,
        delay_before_all_gather: bool,
        delay_before_reduce_scatter: bool,
        delay_before_optim: bool,
        unshard_async_op: bool,
    ):
        # Only test individual delays or all four delays to save test time
        if (
            delay_after_forward
            + delay_before_all_gather
            + delay_before_reduce_scatter
            + delay_before_optim
            in (2, 3)
        ):
            return
        assert test_device_type in ("cuda", "hpu", "xpu", "cpu"), f"{test_device_type}"
        torch.manual_seed(42)
        vocab_size = 1024
        model_args = ModelArgs(
            n_layers=3,
            n_heads=4,
            vocab_size=vocab_size,
            max_seq_len=64,
            dropout_p=0,
        )
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model).to(device_type)

        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        mesh = init_device_mesh(
            test_device_type,
            (self.world_size, 1),
            mesh_dim_names=("replicate", "shard"),
        )
        fully_shard_fn = functools.partial(
            replicate,
            device_mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
        )
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                fully_shard_fn(module)
        fully_shard_fn(model)
        if unshard_async_op:
            model._set_unshard_async_op(unshard_async_op)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        delay_in_ms = 100
        orig_all_gather = dist.all_gather_into_tensor
        orig_reduce_scatter = dist.reduce_scatter_tensor

        def delayed_all_gather(*args, **kwargs):
            torch.get_device_module(device_type)._sleep(
                int(delay_in_ms * get_cycles_per_ms())
            )
            return orig_all_gather(*args, **kwargs)

        def delayed_reduce_scatter(*args, **kwargs):
            torch.get_device_module(device_type)._sleep(
                int(delay_in_ms * get_cycles_per_ms())
            )
            return orig_reduce_scatter(*args, **kwargs)

        torch.manual_seed(42 + self.rank + 1)
        patch_all_gather_ctx = (
            patch_all_gather(delayed_all_gather)
            if delay_before_all_gather
            else contextlib.nullcontext()
        )
        patch_reduce_scatter_ctx = (
            patch_reduce_scatter(delayed_reduce_scatter)
            if delay_before_reduce_scatter
            else contextlib.nullcontext()
        )
        with patch_all_gather_ctx, patch_reduce_scatter_ctx:
            for iter_idx in range(10):
                inp = torch.randint(0, vocab_size, (3, 64), device=device_type)
                losses: list[torch.Tensor] = []
                for _model, _optim in ((ref_model, ref_optim), (model, optim)):
                    losses.append(_model(inp).sum())
                    if _model is model and delay_after_forward:
                        torch.get_device_module(device_type)._sleep(
                            int(delay_in_ms * get_cycles_per_ms())
                        )
                    losses[-1].backward()
                    if _model is model and delay_before_optim:
                        torch.get_device_module(device_type)._sleep(
                            int(delay_in_ms * get_cycles_per_ms())
                        )

                for param in ref_model.parameters():
                    if param.grad is not None:
                        dist.all_reduce(param.grad)
                        param.grad.div_(self.world_size)

                for _optim in (ref_optim, optim):
                    _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                    _optim.step()
                self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_non_root_forward_backward(self):
        """
        Tests running forward/backward through the root and then through a
        non-root. The non-root needs to synchronize streams/queue the callback.
        """
        torch.manual_seed(42)
        lin_dim = 32
        model = nn.Sequential(*[MLP(lin_dim, torch.device("cpu")) for _ in range(3)])
        ref_model = copy.deepcopy(model).to(device_type)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for mlp in model:
            replicate(mlp)
        replicate(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        torch.manual_seed(42 + self.rank)
        inp = torch.randn((8, lin_dim), device=device_type)

        ref_root_loss = ref_model(inp).sum()
        ref_root_loss.backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad)
            param.grad.detach().div_(self.world_size)
        ref_optim.step()
        ref_optim.zero_grad()
        ref_nonroot_loss = ref_model[0](inp).sum()
        ref_nonroot_loss.backward()
        for param in ref_model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)
                param.grad.detach().div_(self.world_size)
        ref_optim.step()

        root_loss = model(inp).sum()
        root_loss.backward()
        torch.get_device_module(device_type)._sleep(int(100 * get_cycles_per_ms()))
        optim.step()
        optim.zero_grad()
        nonroot_loss = model[0](inp).sum()
        nonroot_loss.backward()
        optim.step()

        self.assertEqual(ref_root_loss, root_loss)
        self.assertEqual(ref_nonroot_loss, nonroot_loss)
        self.assertEqual(ref_model(inp).sum(), model(inp).sum())

    @skip_if_lt_x_gpu(2)
    def test_multi_forward_module(self):
        """
        Tests parity when running a module that participates multiple
        times in forward.
        """
        self.run_subtests(
            {"reshard_after_forward": [True, False]},
            self._test_multi_forward_module,
        )

    def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
        class MultiForwardModule(nn.Module):
            def __init__(self, device: torch.device):
                super().__init__()
                self.inner = nn.Linear(4, 4, device=device)
                self.outer = nn.Linear(4, 5, device=device)

            def forward(self, x):
                i = self.inner(x)
                j = self.inner(x)
                return self.outer(i + j)

        torch.manual_seed(42)
        model = MultiForwardModule(device=device_type.type)
        ref_model = copy.deepcopy(model).to(device_type)

        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        replicate(model.inner)
        replicate(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        torch.manual_seed(42 + self.rank)
        inp = torch.randn((32, 4), device=device_type.type)
        for iter_idx in range(10):
            losses: list[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad)
                    param.grad.div_(self.world_size)

            for _optim in (ref_optim, optim):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                _optim.step()

            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    def test_explicit_prefetching(self):
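        # Explicitly schedules prefetching: each transformer block unshards the
        # next two blocks ahead of time in forward and the previous two in
        # backward, then parity against the manually all-reduced reference
        # model is checked.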
        torch.manual_seed(42)
        model_args = ModelArgs(n_layers=8, dropout_p=0.0)
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model).to(device_type)
        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)

        for layer in itertools.chain(model.layers, [model]):
            replicate(layer)
        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

        num_to_forward_prefetch = num_to_backward_prefetch = 2
        for i, layer in enumerate(model.layers):
            if i >= len(model.layers) - num_to_forward_prefetch:
                break
            layers_to_prefetch = [
                model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
            ]
            layer.set_modules_to_forward_prefetch(layers_to_prefetch)
        for i, layer in enumerate(model.layers):
            if i < num_to_backward_prefetch:
                continue
            layers_to_prefetch = [
                model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
            ]
            layer.set_modules_to_backward_prefetch(layers_to_prefetch)

        torch.manual_seed(42 + self.rank)
        inp = torch.randint(0, model_args.vocab_size, (2, 8), device=device_type.type)
        for _ in range(10):
            losses: list[torch.Tensor] = []

            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad)
                    param.grad.div_(self.world_size)

            for _optim in (ref_optim, optim):
                _optim.zero_grad()
                _optim.step()

            self.assertEqual(losses[0], losses[1])

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
    def test_post_optim_event(self):
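        # The optimizer step post-hook records an event on the current stream
        # and hands it to the root module via `set_post_optim_event`, so the
        # next iteration's collectives wait on the optimizer step without a
        # blocking CPU sync.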
        torch.manual_seed(42)
        model_args = ModelArgs(dropout_p=0.0)
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model).to(device_type.type)
        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
        for layer in itertools.chain(model.layers, [model]):
            replicate(layer)
        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

        def step_post_hook(
            fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs
        ) -> None:
            post_optim_event = (
                torch.get_device_module(device_type).current_stream().record_event()
            )
            fsdp_module.set_post_optim_event(post_optim_event)

        optim.register_step_post_hook(functools.partial(step_post_hook, model))

        torch.manual_seed(42 + self.rank)
        inp = torch.randint(0, model_args.vocab_size, (2, 8), device=device_type.type)
        # Track all losses and check for equality at the end to avoid a CPU
        # sync point after each iteration
        ref_losses: list[torch.Tensor] = []
        losses: list[torch.Tensor] = []
        for _ in range(10):
            ref_optim.zero_grad()
            ref_losses.append(ref_model(inp).sum())
            ref_losses[-1].backward()

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad)
                    param.grad.div_(self.world_size)

            ref_optim.step()
        for _ in range(10):
            optim.zero_grad()
            losses.append(model(inp).sum())
            losses[-1].backward()
            optim.step()
            # Sleep after the optimizer step to allow CPU to run ahead into the
            # next iteration's forward, exercising the post-optim stream sync
            torch.get_device_module(device_type)._sleep(int(25 * get_cycles_per_ms()))
        for ref_loss, loss in zip(ref_losses, losses):
            self.assertEqual(ref_loss, loss)


class TestReplicateTrainingCompose(FSDPTest):
    @property
    def world_size(self) -> int:
        # Since these tests run with a larger transformer model, they may see
        # some numeric drift with >2 GPUs
        return min(torch.get_device_module(device_type).device_count(), 2)

    @skip_if_lt_x_gpu(2)
    @compiled_fsdp_test(compile_compute_on_module=Transformer)
    def test_train_parity_with_activation_checkpointing(self):
        """
        Tests train parity against DDP when composing with activation
        checkpointing.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False],
                "checkpoint_impl": ["composable", "utils", "wrapper"],
                "module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"],
                "test_device_type": [device_type.type],
            },
            self._test_train_parity_with_activation_checkpointing,
        )

    def _test_train_parity_with_activation_checkpointing(
        self,
        reshard_after_forward: Union[bool, int],
        checkpoint_impl: str,
        module_grouping: str,
        test_device_type: str,
    ):
        assert checkpoint_impl in ("composable", "utils", "wrapper")
        testing_compile = replicate != torch.distributed._composable.replicate_with_fsdp
        if testing_compile and checkpoint_impl == "composable":
            return
        torch.manual_seed(42)
        vocab_size = 1024
        with torch.device(device_type):
            model_args = ModelArgs(
                n_layers=3,
                n_heads=4,
                vocab_size=vocab_size,
                max_seq_len=64,
                dropout_p=0,
                checkpoint_activations=(checkpoint_impl == "utils"),
                # For the mem-efficient module grouping, we separate the
                # embeddings from the output projection, which does not support
                # weight tying
                weight_tying=module_grouping != "mem_eff",
            )
            model = Transformer(model_args)
        ref_model = copy.deepcopy(model).to(device_type)
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)

        # Apply activation checkpointing
        prefixes_to_ignore = ()
        if checkpoint_impl == "wrapper":
            prefixes_to_ignore = (_CHECKPOINT_PREFIX,)
            apply_activation_checkpointing(
                model, check_fn=lambda m: isinstance(m, TransformerBlock)
            )
        elif checkpoint_impl == "composable":
            for module in model.modules():
                if isinstance(module, TransformerBlock):
                    checkpoint(module)

        # Apply Replicate
        device_mesh = init_device_mesh(
            test_device_type,
            (self.world_size, 1),
            mesh_dim_names=("replicate", "shard"),
        )
        fsdp_kwargs = {
            "reshard_after_forward": reshard_after_forward,
            "device_mesh": device_mesh,
        }
        if module_grouping == "mem_eff":
            assert model_args.n_layers == 3
            replicate(model.layers[0], **fsdp_kwargs)
            replicate([model.layers[1], model.layers[2]], **fsdp_kwargs)
            replicate([model.tok_embeddings, model.pos_embeddings], **fsdp_kwargs)
            # Embedding weights are not needed for embedding backward
            model.tok_embeddings.set_unshard_in_backward(False)
            replicate([model.norm, model.output], **fsdp_kwargs)
        elif module_grouping == "mem_eff_weight_tied":
            replicate([model.tok_embeddings, model.output], **fsdp_kwargs)
            for layer in model.layers:
                replicate(layer, **fsdp_kwargs)
        elif module_grouping == "block":
            for layer in model.layers:
                replicate(layer, **fsdp_kwargs)
        else:
            raise NotImplementedError(f"Unknown module grouping: {module_grouping}")
        replicate(model, **fsdp_kwargs)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        torch.manual_seed(42 + self.rank)
        # Reuse the same input across iterations to avoid loss explosion from
        # trying to learn from random inputs
        inp = torch.randint(0, vocab_size, (3, 64), device=device_type.type)
        check_sharded_parity(
            self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
        )
        for iter_idx in range(10):
            losses: list[torch.Tensor] = []
            for _model in (ref_model, model):
                torch.manual_seed(iter_idx + 1)  # for dropout determinism
                losses.append(_model(inp).sum())
                losses[-1].backward()

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad)
                    param.grad.div_(self.world_size)

            if not testing_compile:
                check_sharded_parity(
                    self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
                )
            self.assertEqual(losses[0], losses[1])
            for _optim in (ref_optim, optim):
                _optim.step()
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            if not testing_compile:
                check_sharded_parity(
                    self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
                )


class TestReplicateSharedParams(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.get_device_module(device_type).device_count())

    @skip_if_lt_x_gpu(2)
    def test_train_parity_with_shared_params(self):
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
            },
            self._test_train_shared_params,
        )

    def _test_train_shared_params(
        self,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
    ):
        torch.manual_seed(42)
        model_args = ModelArgs(n_layers=3, dropout_p=0.0, weight_tying=True)
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model).to(device_type)

        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                if use_activation_checkpointing:
                    checkpoint(module)
                replicate(module, reshard_after_forward=reshard_after_forward)
        replicate(model, reshard_after_forward=reshard_after_forward)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(10):
            inp = torch.randint(
                0, model_args.vocab_size, (2, 16), device=device_type.type
            )
            losses: list[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad)
                    param.grad.div_(self.world_size)

            for _optim in (ref_optim, optim):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                _optim.step()

            self.assertEqual(losses[0], losses[1])


class TestReplicateGradientAccumulation(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.get_device_module(device_type).device_count())

    @skip_if_lt_x_gpu(2)
    def test_gradient_accumulation(self):
        """
        Tests gradient accumulation with/without gradient reduction and
        with/without resharding after backward.
        """

        shard_size, replicate_size = 1, self.world_size
        meshes = init_device_mesh(
            device_type.type,
            (replicate_size, shard_size),
            mesh_dim_names=("replicate", "shard"),
        )
        self.run_subtests(
            {
                "mesh": [meshes],
                "reshard_after_forward": [True, False],
                # "all": disable reduce-scatter for all modules
                # "root_only": disable reduce-scatter for root's linear only
                # "some_mlps": disable reduce-scatter for some MLPs
                "mode": ["all", "root_only", "some_mlps"],
                "reshard_after_backward": [False, True],
                "offload_policy": [OffloadPolicy(), CPUOffloadPolicy()],
                # For HSDP only:
                # `True`: reduce-scatter only (no all-reduce) each microbatch
                # until the last microbatch
                # `False`: neither reduce-scatter nor all-reduce each
                # microbatch until the last microbatch
                "reduce_scatter_only": [False, True],
            },
            self._test_gradient_accumulation,
        )

    def _test_gradient_accumulation(
        self,
        mesh: DeviceMesh,
        reshard_after_forward: Union[bool, int],
        mode: str,
        reshard_after_backward: bool,
        offload_policy: OffloadPolicy,
        reduce_scatter_only: bool,  # for HSDP
    ):
        if (
            (
                not reshard_after_backward
                and (reshard_after_forward is not False or mode == "some_mlps")
            )
            or (
                isinstance(offload_policy, CPUOffloadPolicy)
                and reshard_after_forward is not True
            )
            or (
                mesh.ndim != 2
            )  # may eventually need to change once decision on device mesh is made
        ):
            return  # skip since not common or applicable

        torch.manual_seed(42)
        batch_size, lin_dim, num_mlps, num_microbatches = (2, 32, 3, 3)
        if mode == "some_mlps":
            num_mlps_to_disable_reduce_scatter = 2
        modules = [nn.Linear(lin_dim, lin_dim)]
        modules.extend(MLP(lin_dim) for _ in range(num_mlps))
        model = nn.Sequential(*modules)
        ref_model = copy.deepcopy(model).to(device_type)
        replicate_fn = functools.partial(
            replicate,
            device_mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
        )
        for mlp in model[1:]:
            replicate_fn(mlp)
        replicate_fn(model)  # root gets the 1st linear
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)

        def set_grad_sync_flag(
            module: nn.Module, is_last_microbatch: bool, recurse: bool = True
        ):
            if reduce_scatter_only:
                module.set_requires_all_reduce(is_last_microbatch, recurse=recurse)
            else:
                module.set_requires_gradient_sync(is_last_microbatch, recurse=recurse)

        def set_backward_flags(_model: nn.Module, is_last_microbatch: bool):
            if mode == "all":
                set_grad_sync_flag(_model, is_last_microbatch)
                if not reshard_after_backward:
                    _model.set_reshard_after_backward(is_last_microbatch)
            elif mode == "some_mlps":
                for mlp in model[1 : 1 + num_mlps_to_disable_reduce_scatter]:
                    set_grad_sync_flag(mlp, is_last_microbatch)
                    if not reshard_after_backward:
                        mlp.set_reshard_after_backward(is_last_microbatch)
            elif mode == "root_only":
                set_grad_sync_flag(model, is_last_microbatch, recurse=False)
                if not reshard_after_backward:
                    model.set_reshard_after_backward(is_last_microbatch, recurse=False)

        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(5):
            comm_count_list = []

            for microbatch_idx in range(num_microbatches):
                is_last_microbatch = microbatch_idx == num_microbatches - 1
                set_backward_flags(model, is_last_microbatch)
                inp = torch.randn(batch_size, lin_dim, device=device_type.type)
                losses: list[torch.Tensor] = []
                for _model in (ref_model, model):
                    with CommDebugMode() as comm_mode:
                        losses.append(_model(inp).sum())
                        losses[-1].backward()
                    comm_count_list.append(comm_mode.get_comm_counts())
                self.assertEqual(losses[0], losses[1])

            comm_counts = defaultdict(int)
            for comm_count_dict in comm_count_list:
                for collective, count in comm_count_dict.items():
                    comm_counts[collective] += count

            all_gather_count = comm_counts[c10d_ops._allgather_base_]
            # reduce_scatter_count = comm_counts[c10d_ops._reduce_scatter_base_]
            all_reduce_count = comm_counts[c10d_ops.allreduce_]

            # Expect one all-reduce per MLP plus one for the root's linear on
            # the last microbatch
            # expected_reduce_scatter_count = 0
            expected_all_reduce_count = num_mlps + 1

            if mode == "some_mlps":
                # Expect additional all-reduces for the non-disabled MLPs and
                # the root's linear
                expected_all_reduce_count += (
                    num_mlps - num_mlps_to_disable_reduce_scatter + 1
                ) * (num_microbatches - 1)
            elif mode == "root_only":
                # Expect additional all-reduces for all MLPs
                expected_all_reduce_count += (num_mlps) * (num_microbatches - 1)

            # self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
            self.assertEqual(all_reduce_count, expected_all_reduce_count)

            # With a shard mesh dim of size 1, parameters are not sharded, so
            # no all-gather collectives are expected
            expected_all_gather_count = 0

            self.assertEqual(all_gather_count, expected_all_gather_count)

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
            check_sharded_parity(self, ref_model, model)
            for _optim in (optim, ref_optim):
                _optim.step()
                # When `set_to_none=False`, we are exercising mixing
                # gradient accumulation with and without communication
                _optim.zero_grad(set_to_none=(iter_idx % 2))

    @skip_if_lt_x_gpu(2)
    def test_1f1b_microbatching(self):
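        # Emulates a 1F1B-style schedule: gradient reduction (and optionally
        # resharding) is deferred until the last microbatch; see
        # `_test_1f1b_microbatching` below.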
        self.run_subtests(
            {
                "use_explicit_unshard": [False, True],
                "reshard_after_backward": [False, True],
            },
            self._test_1f1b_microbatching,
        )

    def _test_1f1b_microbatching(
        self, use_explicit_unshard: bool, reshard_after_backward: bool
    ):
        torch.manual_seed(42)
        model_args = ModelArgs(dropout_p=0.0)
        model = Transformer(model_args)
        ref_model = copy.deepcopy(model).to(device_type)
        ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
        for module in model.modules():
            if isinstance(module, TransformerBlock):
                replicate(module, reshard_after_forward=False)
        replicate(model, reshard_after_forward=False)
        optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

        num_microbatches = 3
        local_batch_size = 2
        torch.manual_seed(42 + self.rank + 1)
        inps = [
            torch.randint(
                0,
                model_args.vocab_size,
                (local_batch_size, 16),
                device=device_type.type,
            )
            for _ in range(num_microbatches)
        ]

        # Before pipelining, we may prefer to issue all all-gathers ahead of
        # time to increase overlap opportunity at no difference in parameter
        # memory usage since we do not reshard after forward
        if use_explicit_unshard:
            for module in model.modules():
                if isinstance(module, FSDPModule):
                    module.unshard(async_op=True)

        # Emulate the 1f1b pipeline schedule and only reduce gradients on the
        # last microbatch
        losses: list[torch.Tensor] = []
        ref_losses: list[torch.Tensor] = []
        for inp_idx, inp in enumerate(inps):
            is_last_microbatch = inp_idx == num_microbatches - 1
            model.set_requires_gradient_sync(is_last_microbatch)
            model.set_is_last_backward(is_last_microbatch)
            if not reshard_after_backward:
                model.set_reshard_after_backward(is_last_microbatch)
            losses.append(model(inp).sum())
            losses[-1].backward()
            ref_losses.append(ref_model(inp).sum())
            ref_losses[-1].backward()
            for param in ref_model.parameters():
                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

        for loss, ref_loss in zip(losses, ref_losses):
            self.assertEqual(loss, ref_loss)
        optim.step()
        ref_optim.step()
        check_sharded_parity(self, ref_model, model)


class TestReplicateCustomForwardMethod(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(torch.get_device_module(device_type).device_count(), 2)

    @skip_if_lt_x_gpu(2)
    def test_register_fsdp_forward_method(self):
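        # `register_fsdp_forward_method` lets a non-`forward` method (here
        # `VisionTransformer.forward_features`) run FSDP's pre/post-forward
        # logic so its parameters are unsharded when it is called directly.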
        class VisionTransformer(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.patch_proj = nn.Conv2d(3, 1024, kernel_size=14, stride=14)

            def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.patch_proj(imgs).flatten(2).transpose(1, 2)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                return self.forward_features(imgs).sum(dim=1)

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.vit, self.projector = VisionTransformer(), nn.Linear(1024, 256)

            def forward(self, imgs: torch.Tensor) -> torch.Tensor:
                # Run `vit.forward_features`, which is not `forward`!
                patch_embeddings = self.vit.forward_features(imgs)
                return self.projector(patch_embeddings)

        torch.manual_seed(42)
        model = Model()
        ref_model = copy.deepcopy(model).to(device_type)
        replicate(model.vit)
        replicate(model.projector)
        replicate(model)
        register_fsdp_forward_method(model.vit, "forward_features")

        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn(4, 3, 224, 224, device=device_type.type)
        ref_loss = ref_model(inp).sum()
        loss = model(inp).sum()
        self.assertEqual(ref_loss, loss)
        ref_loss.backward()
        loss.backward()
        for param in ref_model.parameters():
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
        check_sharded_parity(self, ref_model, model)


class TestReplicateTPTraining(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.get_device_module(device_type).device_count())

    def init_global_mesh(self) -> DeviceMesh:
        return init_device_mesh(
            device_type.type,
            (2, 1, 2),
            mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
        )

    @skip_if_lt_x_gpu(8)
    def test_replicate_tp(self):
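        # Composes replicate with tensor parallelism: TP is applied first over
        # the "tp" mesh dim, then replicate over ("dp_replicate", "dp_shard"),
        # so every parameter ends up as a 3-D DTensor (see `_test_replicate_tp`).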
        global_mesh = self.init_global_mesh()
        self.run_subtests(
            {
                "reshard_after_forward": [False, True],
                "use_activation_checkpointing": [False, True],
                "mlp_dim": [3, 5, 16, 17],
                "foreach": [False],
            },
            functools.partial(self._test_replicate_tp, global_mesh),
        )

    def _test_replicate_tp(
        self,
        global_mesh: DeviceMesh,
        reshard_after_forward: bool,
        use_activation_checkpointing: bool,
        mlp_dim: int,
        foreach: bool,
    ):
        dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
        dp_pg = dp_mesh._flatten().get_group()  # used for `replicate()`

        torch.manual_seed(42)
        model = MLPStack(mlp_dim)
        ref_model = copy.deepcopy(model).to(device_type)

        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)

        parallelize_plan = {
            # Pass `use_local_output=False` to keep as DTensor to preserve
            # uneven activation dims
            "0.in_proj": ColwiseParallel(use_local_output=False),
            "0.out_proj": RowwiseParallel(use_local_output=False),
            "1.in_proj": ColwiseParallel(use_local_output=False),
            "1.out_proj": RowwiseParallel(use_local_output=False),
            "2.in_proj": ColwiseParallel(use_local_output=False),
            "2.out_proj": (RowwiseParallel()),
        }

        model = parallelize_module(model, tp_mesh, parallelize_plan)

        for module in model:
            if isinstance(module, nn.LayerNorm):
                continue
            if use_activation_checkpointing:
                checkpoint(module)
            replicate(module, device_mesh=dp_mesh)
        replicate(model, device_mesh=dp_mesh)

        # Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
        # strided-sharded layers.
        for ref_p, p in zip(ref_model.parameters(), model.parameters()):
            self.assertIsInstance(p, DTensor)
            self.assertEqual(ref_p, p.full_tensor())

        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

        torch.manual_seed(42 + dp_pg.rank() + 1)
        device = device_type
        for iter_idx in range(10):
            inp = torch.randn((8, mlp_dim), device=device)
            losses: list[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()

            for param in ref_model.parameters():
                if param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

            for _optim in (ref_optim, optim):
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                _optim.step()
            self.assertEqual(losses[0], losses[1])
            check_sharded_parity(self, ref_model, model)

        for _, p in model.named_parameters():
            self.assertIsInstance(p, DTensor)
            self.assertEqual(p.device_mesh.ndim, 3)
            self.assertEqual(len(p.placements), 3)
            self.assertEqual(
                p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
            )


if __name__ == "__main__":
    run_tests()