mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Add two unit tests: (1) an HSDP checkpoint unit test; (2) an HSDP-to-FSDP checkpoint conversion unit test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111399. Approved by: https://github.com/wanchaol
213 lines
7.4 KiB
Python
213 lines
7.4 KiB
Python
# Owner(s): ["oncall: distributed"]
|
|
from copy import deepcopy
|
|
|
|
import torch
|
|
import torch.distributed.checkpoint as dist_cp
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.distributed._tensor import init_device_mesh, Replicate
|
|
|
|
from torch.distributed.checkpoint.default_planner import (
|
|
DefaultLoadPlanner,
|
|
DefaultSavePlanner,
|
|
)
|
|
|
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
|
ShardingStrategy,
|
|
StateDictType,
|
|
)
|
|
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
)
|
|
|
|
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
|
DTensorTestBase,
|
|
with_comms,
|
|
)
|
|
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
|
|
|
|
|
class SimpleModel(torch.nn.Module):
    """Three-layer MLP (5 -> 8 -> 4 -> 12) with a ReLU after every linear layer."""

    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(5, 8)
        # Not called by ``forward`` (which uses ``F.relu``). Kept so the module
        # tree stays unchanged; ReLU has no parameters, so it adds no
        # state-dict entries.
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(8, 4)
        self.net3 = nn.Linear(4, 12)

    def forward(self, x):
        """Apply net1, net2 and net3 in order, each followed by ReLU."""
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        return x

    def get_input(self):
        """Return a random input batch of shape (4, 5) allocated on CUDA."""
        return torch.rand(4, 5, device="cuda")
|
|
|
|
|
|
class SimpleModelUneven(torch.nn.Module):
    """Four-layer MLP (5 -> 10 -> 15 -> 30 -> 5) with a ReLU after every linear layer.

    NOTE(review): the name suggests the layer sizes were chosen so parameters
    do not split evenly across ranks -- confirm against the FSDP sharding
    rules for the world size in use.
    """

    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(5, 10)
        # Not called by ``forward`` (which uses ``F.relu``). Kept so the module
        # tree stays unchanged; ReLU has no parameters, so it adds no
        # state-dict entries.
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 15)
        self.net3 = nn.Linear(15, 30)
        self.net4 = nn.Linear(30, 5)

    def forward(self, x):
        """Apply net1 through net4 in order, each followed by ReLU."""
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        return x

    def get_input(self):
        """Return a random input batch of shape (4, 5) allocated on CUDA."""
        return torch.rand(4, 5, device="cuda")
|
|
|
|
|
|
class TestHSDPCheckpoint(DTensorTestBase):
    """Distributed-checkpoint tests for models wrapped with FSDP HYBRID_SHARD (HSDP)."""

    @property
    def backend(self):
        """Return the device-to-backend mapping: gloo for CPU, NCCL for CUDA.

        NOTE(review): presumably consumed by ``DTensorTestBase`` when it sets up
        the process group -- confirm against the base class.
        """
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_checkpoint(self, is_even_sharded_model) -> None:
        """Round-trip an HSDP sharded state dict through a checkpoint.

        Saves the initial sharded state dict, changes the parameters with one
        optimizer step, reloads the checkpoint and checks that the model's
        state dict matches the saved one again.

        Args:
            is_even_sharded_model: selects ``SimpleModel`` when true, otherwise
                ``SimpleModelUneven``.
        """
        checkpoint_dir = self.temp_dir
        model_cls = SimpleModel if is_even_sharded_model else SimpleModelUneven

        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        model = FSDP(
            model_cls().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict = {"model": model.state_dict()}
        # Deep copy so the snapshot is decoupled from the live model before
        # the parameters are updated below.
        state_dict_to_save = deepcopy(state_dict)

        dist_cp.save_state_dict(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(checkpoint_dir),
            planner=DefaultSavePlanner(),
        )

        # Update the parameters so that the current model state_dict differs
        # from state_dict_to_save.
        model(model.get_input()).sum().backward()
        optim.step()

        # At this point, the current state dict is different from state_dict_to_save.
        state_dict_after_step = model.state_dict()
        # zip() stops at the shorter input, so check the entry counts
        # explicitly; otherwise a missing key would go unnoticed.
        self.assertEqual(len(state_dict_to_save["model"]), len(state_dict_after_step))
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), state_dict_after_step.items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertNotEqual(v1.to_local(), v2.to_local())

        dist_cp.load_state_dict(
            state_dict=state_dict_to_save,
            storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
            planner=DefaultLoadPlanner(),
        )
        model.load_state_dict(state_dict_to_save["model"])

        state_dict_after_load = model.state_dict()
        # After loading, the current model state dict should be the same as state_dict_to_save.
        self.assertEqual(len(state_dict_to_save["model"]), len(state_dict_after_load))
        for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertEqual(v1.device_mesh, v2.device_mesh)
            self.assertEqual(v1.placements, v2.placements)
            self.assertEqual(v1.to_local(), v2.to_local())

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    @parametrize("is_even_sharded_model", [True, False])
    def test_hsdp_fsdp_checkpoint_conversion(self, is_even_sharded_model) -> None:
        """Load a checkpoint written by an HSDP model into a 1-D-mesh FSDP model.

        Saves the sharded state dict of a model wrapped with HYBRID_SHARD on a
        2-D mesh, loads it into a separately initialized model wrapped with
        FSDP on a 1-D mesh, and checks that the fully replicated tensors of
        both models are equal afterwards.

        Args:
            is_even_sharded_model: selects ``SimpleModel`` when true, otherwise
                ``SimpleModelUneven``.
        """
        checkpoint_dir = self.temp_dir
        model_cls = SimpleModel if is_even_sharded_model else SimpleModelUneven

        # save the hsdp model state_dict
        mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
        hsdp_model = FSDP(
            model_cls().cuda(),
            sharding_strategy=ShardingStrategy.HYBRID_SHARD,
            device_mesh=mesh_2d,
        )
        FSDP.set_state_dict_type(
            hsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        hsdp_state_dict = {"model": hsdp_model.state_dict()}
        dist_cp.save_state_dict(
            state_dict=hsdp_state_dict,
            storage_writer=dist_cp.FileSystemWriter(checkpoint_dir),
            planner=DefaultSavePlanner(),
        )

        # initialize a fsdp model to load checkpoint into
        mesh_1d = init_device_mesh(self.device_type, (self.world_size,))
        fsdp_model = FSDP(
            model_cls().cuda(),
            device_mesh=mesh_1d,
        )
        FSDP.set_state_dict_type(
            fsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        fsdp_state_dict = {"model": fsdp_model.state_dict()}

        # at this point, the hsdp model parameters are different from fsdp model parameters.
        # zip() stops at the shorter input, so check the entry counts
        # explicitly; otherwise a missing key would go unnoticed.
        self.assertEqual(len(hsdp_state_dict["model"]), len(fsdp_state_dict["model"]))
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), fsdp_state_dict["model"].items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
            # The meshes differ, so replicate each tensor on its own mesh and
            # compare the local copies.
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertNotEqual(v1_all_gather.to_local(), v2_all_gather.to_local())

        # load the fsdp state_dict from storage
        dist_cp.load_state_dict(
            state_dict=fsdp_state_dict,
            storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
            planner=DefaultLoadPlanner(),
        )
        fsdp_model.load_state_dict(fsdp_state_dict["model"])

        state_dict_after_load = fsdp_model.state_dict()
        # After loading, the current model state dict should be the same as hsdp_state_dict.
        self.assertEqual(len(hsdp_state_dict["model"]), len(state_dict_after_load))
        for (k1, v1), (k2, v2) in zip(
            hsdp_state_dict["model"].items(), state_dict_after_load.items()
        ):
            self.assertEqual(k1, k2)
            self.assertNotEqual(v1.device_mesh, v2.device_mesh)
            self.assertNotEqual(v1.placements, v2.placements)
            v1_all_gather = v1.redistribute(
                mesh_2d, placements=(Replicate(), Replicate())
            )
            v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
            self.assertEqual(v1_all_gather.to_local(), v2_all_gather.to_local())
|
|
|
|
|
|
# Register the variants of the tests declared with @parametrize on the class.
instantiate_parametrized_tests(TestHSDPCheckpoint)

if __name__ == "__main__":
    run_tests()
|