Revert "[FSDP2] Added test for N-way TP and 1-way FSDP with CPU offloading (#127024)"

This reverts commit 9117779b0a.

Reverted https://github.com/pytorch/pytorch/pull/127024 on behalf of https://github.com/atalman due to failing in CI ([comment](https://github.com/pytorch/pytorch/pull/127024#issuecomment-2133566325))
PyTorch MergeBot 2024-05-27 14:12:09 +00:00
parent 7121ea6f70
commit c7f6fbfa9d
2 changed files with 6 additions and 63 deletions
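
For reference, the reverted test exercised FSDP2's fully_shard composed with tensor parallelism on a (1-way DP, N-way TP) device mesh plus CPU offloading. Below is a minimal sketch of that setup, not the reverted test itself: the toy nn.Sequential model, the parallelize plan, and the world_size value are illustrative placeholders, and the script assumes one process per GPU (e.g. launched via torchrun).

import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

world_size = 8  # placeholder: every rank is placed on the TP dimension
# 1-way data parallelism x N-way tensor parallelism, as in the reverted test
mesh = init_device_mesh("cuda", (1, world_size), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh["dp"], mesh["tp"]

# Toy two-layer MLP standing in for the MLPStack used by the test
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
parallelize_module(model, tp_mesh, {"0": ColwiseParallel(), "2": RowwiseParallel()})

# With a 1-way DP mesh the FSDP collectives are effectively no-ops, but
# CPUOffloadPolicy still keeps the sharded parameters on CPU between uses.
fully_shard(
    model,
    mesh=dp_mesh,
    reshard_after_forward=True,
    offload_policy=CPUOffloadPolicy(),
)

After fully_shard returns, the model's parameters live on CPU (param.device.type == "cpu"), which is what the removed test asserted before running its training loop.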


@@ -56,7 +56,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
c10d_ops = torch.ops.c10d
funcol = torch.ops.c10d_functional
class TestFullyShardForwardInputs(FSDPTestMultiThread):
@@ -928,10 +927,7 @@ class TestFullyShard2DTraining(FSDPTest):
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
-            tp_mesh,
-            dp_mesh,
-            use_activation_checkpointing,
-            reshard_after_forward=reshard_after_forward,
+            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
@@ -947,61 +943,6 @@ class TestFullyShard2DTraining(FSDPTest):
                 _optim.step()
             self.assertEqual(losses[0], losses[1])

-    @skip_if_lt_x_gpu(2)
-    def test_tp_with_fsdp_offloading(self):
-        global_mesh = init_device_mesh(
-            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
-        )
-        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
-        torch.manual_seed(42)
-        mlp_dim = 16
-        model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
-        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
-        # Parallelize with N-way TP and 1-way FSDP
-        model.parallelize(
-            tp_mesh,
-            dp_mesh,
-            use_activation_checkpointing=False,
-            reshard_after_forward=True,
-            offload_policy=CPUOffloadPolicy(),
-        )
-        for param in model.parameters():
-            self.assertEqual(param.device.type, "cpu")
-        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
-        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
-        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
-        # called, but they will just be no-ops without issuing any kernels.
-        # We prefer to keep the no-op check at the c10d level, not in FSDP.
-        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
-        for iter_idx in range(10):
-            ref_optim.zero_grad()
-            optim.zero_grad()
-            with CommDebugMode() as fwd_comm_mode:
-                loss = model(inp).sum()
-            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
-            self.assertEqual(len(fwd_comm_counts), 2)
-            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
-            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
-            ref_loss = ref_model(inp).sum()
-            self.assertEqual(loss, ref_loss)
-            with CommDebugMode() as bwd_comm_mode:
-                loss.backward()
-            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
-            self.assertEqual(len(bwd_comm_counts), 3)
-            # First MLP's input gradient does not need to be all-reduced
-            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
-            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
-            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
-            ref_loss.backward()
-            optim.step()
-            ref_optim.step()
-
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
     def test_train_parity_2d_transformer_checkpoint_resume(self):


@@ -893,7 +893,7 @@ class MLPStack(nn.Sequential):
         tp_mesh: DeviceMesh,
         dp_mesh: DeviceMesh,
         use_activation_checkpointing: bool,
-        **fsdp_kwargs,
+        reshard_after_forward: bool,
     ) -> "MLPStack":
         parallelize_plan = {
             # Pass `use_local_output=False` to keep as DTensor to preserve
@@ -915,8 +915,10 @@ class MLPStack(nn.Sequential):
                 continue
             if use_activation_checkpointing:
                 checkpoint(module)
-            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
-        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
+            fully_shard(
+                module, mesh=dp_mesh, reshard_after_forward=reshard_after_forward
+            )
+        fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
         return self