[FSDP] Change _create_chunk_dtensor in fsdp/_shard_utils.py to use public API from DTensor (#110831)

This PR:
1) updates _create_chunk_dtensor() in _shard_utils.py to use public APIs from DTensor. This avoids the global_size calculation error that arises when DTensor.from_local() is used for unevenly sharded parameters, as described in https://github.com/pytorch/pytorch/issues/110762 (a short sketch of the problem follows this list).
2) updates test/distributed/fsdp/test_fsdp_dtensor_state_dict.py to add a unit test for a model with uneven sharding.
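
For context, a minimal sketch of the problem and of the approach this PR switches to. This is not the PR's code: it assumes a 2-rank CPU/gloo launch via torchrun, and the script name and tensor values are purely illustrative.

```python
# uneven_shard_demo.py -- hypothetical repro; run with: torchrun --nproc_per_node=2 uneven_shard_demo.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard

dist.init_process_group("gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()
mesh = DeviceMesh("cpu", torch.arange(world_size))

# 5 elements cannot be split evenly across 2 ranks: rank 0 holds 3, rank 1 holds 2.
full = torch.arange(5.0)
chunks = list(full.chunk(world_size))
local = chunks[rank] if rank < len(chunks) else full.new_empty(0)

# DTensor.from_local() infers the global shape from the local shard
# (local size * number of shards along the sharded dim), so the ranks
# disagree on the global shape: 6 on rank 0, 4 on rank 1 -- never 5.
naive = DTensor.from_local(local, mesh, [Shard(0)], run_check=False)
print(f"rank {rank}: from_local() global shape = {tuple(naive.shape)}")

# The approach used by the fix: treat the full tensor as replicated, then
# redistribute to Shard(0). The global shape stays (5,), and uneven chunking
# is handled internally.
sharded = DTensor.from_local(full, mesh, [Replicate()], run_check=False).redistribute(
    device_mesh=mesh, placements=[Shard(0)]
)
print(
    f"rank {rank}: redistributed global shape = {tuple(sharded.shape)}, "
    f"local shape = {tuple(sharded.to_local().shape)}"
)

dist.destroy_process_group()
```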

cc. @wanchaol, @fegin

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110831
Approved by: https://github.com/wanchaol, https://github.com/fegin
wz337 2023-10-08 17:54:43 -07:00 committed by PyTorch MergeBot
parent 6e770c0dda
commit d9eb5a57aa
2 changed files with 74 additions and 38 deletions

test/distributed/fsdp/test_fsdp_dtensor_state_dict.py

@@ -46,9 +46,29 @@ class TestDummyModel(torch.nn.Module):
return torch.rand(8, 8, device="cuda")
class TestDummyModelUneven(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(5, 10), nn.ReLU())
self.net2 = nn.Sequential(nn.Linear(10, 15), nn.ReLU())
self.net3 = nn.Linear(15, 30)
self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(30, 5))
def forward(self, x):
return self.net4(self.net3(self.net2(self.net1(x))))
def get_input(self):
return torch.rand(5, 5, device="cuda")
class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
def _create_model(self, device_mesh=None):
model = FSDP(TestDummyModel().cuda(), device_mesh=device_mesh)
def _create_model(self, is_even_sharded_model, device_mesh=None):
dummy_model = (
TestDummyModel() if is_even_sharded_model else TestDummyModelUneven()
)
model = FSDP(dummy_model.cuda(), device_mesh=device_mesh)
optim = torch.optim.Adam(model.parameters(), lr=0.1)
model(model.get_input()).sum().backward()
optim.step()
@@ -57,9 +77,10 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
def test_fsdp_init_with_device_mesh(self):
@parametrize("is_even_sharded_model", [True, False])
def test_fsdp_init_with_device_mesh(self, is_even_sharded_model):
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model, optim = self._create_model(device_mesh)
model, optim = self._create_model(is_even_sharded_model, device_mesh)
FSDP.set_state_dict_type(
model,
@@ -68,7 +89,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
state_dict = model.state_dict()
optim_state_dict = FSDP.optim_state_dict(model, optim)
for v in model.state_dict().values():
for v in state_dict.values():
self.assertEqual(type(v), DTensor)
self.assertEqual(len(v.placements), 1)
self.assertEqual(v.placements[0], (Shard(dim=0)))
@@ -91,9 +112,12 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
@parametrize("offload_to_cpu", [True, False])
def test_dtensor_sharded_tensor_state_dict_identical(self, offload_to_cpu):
@parametrize("is_even_sharded_model", [True, False])
def test_dtensor_sharded_tensor_state_dict_identical(
self, offload_to_cpu, is_even_sharded_model
):
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model, optim = self._create_model(device_mesh)
model, optim = self._create_model(is_even_sharded_model, device_mesh)
FSDP.set_state_dict_type(
model,
@@ -106,7 +130,7 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
dtensor_sd = model.state_dict()
dtensor_osd = FSDP.optim_state_dict(model, optim)
ref_model, ref_optim = self._create_model()
ref_model, ref_optim = self._create_model(is_even_sharded_model)
FSDP.set_state_dict_type(
ref_model,
StateDictType.SHARDED_STATE_DICT,
@@ -126,12 +150,17 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
k2, v2 = sharded_tensor_sd
self.assertEqual(k1, k2)
self.assertEqual(type(v1), DTensor)
self.assertEqual(type(v2), ShardedTensor)
# check whether local_tensor are the same
self.assertEqual(v1.to_local(), v2.local_tensor())
# check whether device are the same
self.assertEqual(v1.to_local().device, v2.local_tensor().device)
# if the ShardedTensor has no local shards,
# the corresponding DTensor's local tensor should be empty, i.e. local_tensor=tensor([])
if len(v2.local_shards()) == 0:
self.assertEqual(v1.to_local().numel(), 0)
else:
self.assertEqual(type(v1), DTensor)
self.assertEqual(type(v2), ShardedTensor)
# check whether local_tensor are the same
self.assertEqual(v1.to_local(), v2.local_tensor())
# check whether device are the same
self.assertEqual(v1.to_local().device, v2.local_tensor().device)
# Check that dtensor and sharded_tensor optim state dict values are identical
for dtensor_osd_state, sharded_tensor_osd_state in zip(
@@ -148,21 +177,29 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
self.assertEqual(k1, k2)
if k1 != "step":
self.assertEqual(type(v1), DTensor)
self.assertEqual(type(v2), ShardedTensor)
# check whether local_tensor are the same
self.assertEqual(v1.to_local(), v2.local_tensor())
# check whether device are the same
self.assertEqual(v1.to_local().device, v2.local_tensor().device)
# if the ShardedTensor has no local shards,
# the corresponding DTensor's local tensor should be empty, i.e. local_tensor=tensor([])
if len(v2.local_shards()) == 0:
self.assertEqual(v1.to_local().numel(), 0)
else:
self.assertEqual(type(v1), DTensor)
self.assertEqual(type(v2), ShardedTensor)
# check whether local_tensor are the same
self.assertEqual(v1.to_local(), v2.local_tensor())
# check whether device are the same
self.assertEqual(v1.to_local().device, v2.local_tensor().device)
else:
self.assertEqual(v1, v2)
@with_comms
@skip_if_lt_x_gpu(2)
@parametrize("offload_to_cpu", [True, False])
def test_dtensor_sharded_optim_load_state_dict(self, offload_to_cpu):
@parametrize("is_even_sharded_model", [True, False])
def test_dtensor_sharded_optim_load_state_dict(
self, offload_to_cpu, is_even_sharded_model
):
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model, optim = self._create_model(device_mesh)
model, optim = self._create_model(is_even_sharded_model, device_mesh)
FSDP.set_state_dict_type(
model,
@@ -215,9 +252,12 @@ class TestFSDPWithDeviceMeshAndDTensor(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
@parametrize("offload_to_cpu", [True, False])
def test_dtensor_sharded_model_load_state_dict(self, offload_to_cpu):
@parametrize("is_even_sharded_model", [True, False])
def test_dtensor_sharded_model_load_state_dict(
self, offload_to_cpu, is_even_sharded_model
):
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model, optim = self._create_model(device_mesh)
model, optim = self._create_model(is_even_sharded_model, device_mesh)
FSDP.set_state_dict_type(
model,

torch/distributed/fsdp/_shard_utils.py

@@ -177,21 +177,17 @@ def _create_chunk_dtensor(
Shard a tensor into chunks along the first dimension. The local rank gets its
corresponding chunk as the local tensor to create a DTensor.
"""
inner_dim = device_mesh.ndim - 1
shard_placement = DShard(0)
tensor_list, _ = shard_placement._split_tensor(
tensor,
device_mesh.size(dim=inner_dim),
with_padding=False,
contiguous=True,
)
# We need to explicitly call .clone() here because tensor.chunk() splits a tensor into the specified number of chunks.
# Each chunk is a view of the input tensor. If the original tensor changes, the view changes with it.
# We need to explicitly call .detach() to return a new tensor detached from the current graph.
local_tensor = tensor_list[rank].clone().detach()
tensor = tensor.clone().detach()
# FSDP placements: [Shard(0)]
# HSDP placements: [Replicate(), Shard(0)]
placements = [Replicate() for _ in range(device_mesh.ndim)]
placements[-1] = shard_placement # type: ignore[call-overload]
return DTensor.from_local(local_tensor, device_mesh, placements)
replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
shard_placements[-1] = DShard(0) # type: ignore[call-overload]
shard_placements = tuple(shard_placements)
return DTensor.from_local(tensor, device_mesh, replicate_placements).redistribute(
device_mesh=device_mesh,
placements=shard_placements,
)
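
As a footnote to the FSDP/HSDP placement comment above, here is a small illustrative helper, not part of the PR, showing how the two placement lists relate to the mesh dimensionality: every mesh dim stays Replicate() except the innermost one, which shards the tensor along dim 0.

```python
from torch.distributed._tensor import Replicate, Shard

def placements_for(mesh_ndim: int):
    # Mirrors the construction in _create_chunk_dtensor: all mesh dims replicated,
    # then the innermost mesh dim is switched to shard the tensor along dim 0.
    replicate_placements = [Replicate() for _ in range(mesh_ndim)]
    shard_placements = [Replicate() for _ in range(mesh_ndim)]
    shard_placements[-1] = Shard(0)
    return replicate_placements, tuple(shard_placements)

# FSDP, 1-D mesh:  source [Replicate()]               -> target (Shard(0),)
# HSDP, 2-D mesh:  source [Replicate(), Replicate()]  -> target (Replicate(), Shard(0))
print(placements_for(1))
print(placements_for(2))
```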