Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00

Revert "[FSDP2] Added test for N-way TP and 1-way FSDP with CPU offloading (#127024)"
This reverts commit 9117779b0a.
Reverted https://github.com/pytorch/pytorch/pull/127024 on behalf of https://github.com/atalman due to failing in CI ([comment](https://github.com/pytorch/pytorch/pull/127024#issuecomment-2133566325))
This commit is contained in:
parent 7121ea6f70
commit c7f6fbfa9d

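For context, the test removed by this revert exercised `fully_shard` over a (dp=1, tp=N) device mesh with `CPUOffloadPolicy`, so parameters live on CPU and the size-1 FSDP collectives are recorded but issue no kernels. The following is a minimal sketch of that layout, not the reverted test itself; it assumes an already-initialized process group, the `torch.distributed._composable.fsdp` import path in use around this commit (newer releases expose the same APIs under `torch.distributed.fsdp`), and a stand-in `nn.Sequential` model with an illustrative tensor-parallel plan.

    # Minimal sketch (assumptions noted above, not the reverted test): N-way tensor
    # parallelism on the "tp" mesh dim, 1-way FSDP on the "dp" mesh dim, with the
    # sharded parameters offloaded to CPU between uses.
    import torch
    import torch.nn as nn
    from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    world_size = torch.distributed.get_world_size()  # assumes dist is initialized
    # The dp dim has size 1, so FSDP "sharding" is trivial; tp spans all ranks.
    global_mesh = init_device_mesh("cuda", (1, world_size), mesh_dim_names=("dp", "tp"))
    dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))  # stand-in
    parallelize_module(
        model,
        tp_mesh,
        {"0": ColwiseParallel(), "2": RowwiseParallel()},  # shard the two Linears
    )
    fully_shard(model, mesh=dp_mesh, offload_policy=CPUOffloadPolicy())

    # With CPU offloading enabled, the sharded parameters are kept on CPU,
    # which is what the reverted test asserted.
    assert all(p.device.type == "cpu" for p in model.parameters())
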
@@ -56,7 +56,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
-c10d_ops = torch.ops.c10d
 funcol = torch.ops.c10d_functional
 
 
 class TestFullyShardForwardInputs(FSDPTestMultiThread):

@@ -928,10 +927,7 @@ class TestFullyShard2DTraining(FSDPTest):
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
-            tp_mesh,
-            dp_mesh,
-            use_activation_checkpointing,
-            reshard_after_forward=reshard_after_forward,
+            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
 

@@ -947,61 +943,6 @@ class TestFullyShard2DTraining(FSDPTest):
                 _optim.step()
             self.assertEqual(losses[0], losses[1])
 
-    @skip_if_lt_x_gpu(2)
-    def test_tp_with_fsdp_offloading(self):
-        global_mesh = init_device_mesh(
-            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
-        )
-        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
-        torch.manual_seed(42)
-        mlp_dim = 16
-        model = MLPStack(mlp_dim)
-        ref_model = copy.deepcopy(model).cuda()
-        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
-        # Parallelize with N-way TP and 1-way FSDP
-        model.parallelize(
-            tp_mesh,
-            dp_mesh,
-            use_activation_checkpointing=False,
-            reshard_after_forward=True,
-            offload_policy=CPUOffloadPolicy(),
-        )
-        for param in model.parameters():
-            self.assertEqual(param.device.type, "cpu")
-        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
-        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
-
-        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
-        # called, but they will just be no-ops without issuing any kernels.
-        # We prefer to keep the no-op check at the c10d level, not in FSDP.
-        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
-        for iter_idx in range(10):
-            ref_optim.zero_grad()
-            optim.zero_grad()
-
-            with CommDebugMode() as fwd_comm_mode:
-                loss = model(inp).sum()
-
-            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
-            self.assertEqual(len(fwd_comm_counts), 2)
-            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
-            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
-            ref_loss = ref_model(inp).sum()
-            self.assertEqual(loss, ref_loss)
-
-            with CommDebugMode() as bwd_comm_mode:
-                loss.backward()
-            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
-            self.assertEqual(len(bwd_comm_counts), 3)
-            # First MLP's input gradient does not need to be all-reduced
-            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
-            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
-            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
-            ref_loss.backward()
-
-            optim.step()
-            ref_optim.step()
-
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
     def test_train_parity_2d_transformer_checkpoint_resume(self):

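The NOTE in the removed test captures the design point: with a size-1 FSDP mesh dim the all-gather/reduce-scatter c10d ops are still invoked (and thus still counted), they just launch no kernels, so the check is done at the c10d level via `CommDebugMode` rather than inside FSDP. Below is a condensed, hedged sketch of that counting pattern; it assumes `model` and a parallel setup like the sketch above, and the `CommDebugMode` import path that was common around this commit (the public path has moved in newer releases).

    # Condensed sketch of the CommDebugMode counting pattern from the removed test.
    # Assumes `model` is the TP+FSDP-parallelized model from the earlier sketch.
    import torch
    from torch.distributed._tensor.debug import CommDebugMode  # path may vary by release

    c10d_ops = torch.ops.c10d
    funcol = torch.ops.c10d_functional

    inp = torch.randn((4, 16), device="cuda")  # same on all ranks

    with CommDebugMode() as fwd_comm_mode:
        loss = model(inp).sum()
    fwd_counts = fwd_comm_mode.get_comm_counts()
    # Forward: one TP all-reduce per MLP plus one (no-op) FSDP all-gather per MLP.
    print(fwd_counts[funcol.all_reduce], fwd_counts[c10d_ops._allgather_base_])

    with CommDebugMode() as bwd_comm_mode:
        loss.backward()
    bwd_counts = bwd_comm_mode.get_comm_counts()
    # Backward: the first MLP's input gradient needs no all-reduce, so expect one
    # fewer TP all-reduce, plus per-MLP FSDP all-gathers and reduce-scatters.
    print(
        bwd_counts[funcol.all_reduce],
        bwd_counts[c10d_ops._allgather_base_],
        bwd_counts[c10d_ops._reduce_scatter_base_],
    )
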
@@ -893,7 +893,7 @@ class MLPStack(nn.Sequential):
         tp_mesh: DeviceMesh,
         dp_mesh: DeviceMesh,
         use_activation_checkpointing: bool,
-        **fsdp_kwargs,
+        reshard_after_forward: bool,
     ) -> "MLPStack":
         parallelize_plan = {
             # Pass `use_local_output=False` to keep as DTensor to preserve

@@ -915,8 +915,10 @@ class MLPStack(nn.Sequential):
                 continue
             if use_activation_checkpointing:
                 checkpoint(module)
-            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
-        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
+            fully_shard(
+                module, mesh=dp_mesh, reshard_after_forward=reshard_after_forward
+            )
+        fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
         return self
 

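On the API side, the reverted change had widened `MLPStack.parallelize` from an explicit `reshard_after_forward: bool` to `**fsdp_kwargs`, so callers could forward arbitrary `fully_shard` options such as `offload_policy=CPUOffloadPolicy()`; this revert restores the narrower signature. A hedged sketch of that kwargs-forwarding pattern follows, using a hypothetical `shard_stack` helper rather than the real `MLPStack` method.

    # Hypothetical helper illustrating the **fsdp_kwargs pass-through that the
    # reverted change used: extra keyword arguments go straight to fully_shard().
    import torch.nn as nn
    from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
    from torch.distributed.device_mesh import DeviceMesh


    def shard_stack(
        stack: nn.Sequential, dp_mesh: DeviceMesh, **fsdp_kwargs
    ) -> nn.Sequential:
        for module in stack:
            if next(module.parameters(), None) is None:
                continue  # skip parameter-free children (e.g. activations)
            # Shard each child as its own FSDP parameter group...
            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
        # ...then the root, to cover anything not already sharded.
        fully_shard(stack, mesh=dp_mesh, **fsdp_kwargs)
        return stack


    # Usage (illustrative): shard_stack(model, dp_mesh,
    #     reshard_after_forward=True, offload_policy=CPUOffloadPolicy())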