pytorch/test/distributed/_shard
IvanKobzarev bab4b5a341 [dist][sharded_tensor] Fix ChunkShardingSpec metadata offsets for empty shards (#121002)
ChunkShardingSpec generated shard metadata whose offsets exceed the tensor size when some shards are empty.

Example:

The Torchrec-prepared ShardedTensorMetadata:
```
ShardedTensorMetadata(shards_metadata=[
ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 512], placement=rank:0/cuda:0),
ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 512], placement=rank:1/cuda:1),
ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 512], placement=rank:2/cuda:2),
ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 512], placement=rank:3/cuda:3),
ShardMetadata(shard_offsets=[8, 0], shard_sizes=[2, 512], placement=rank:4/cuda:4),
ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:5/cuda:5),
ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:6/cuda:6)
],
size=torch.Size([10, 512]))
```
Calling ShardedTensor._init_from_local_shards_and_global_metadata() deduces a ShardingSpec from this metadata, and that spec then rebuilds the shard metadata as:

```
ShardedTensorMetadata(shards_metadata=[
ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 512], placement=rank:0/cuda:0),
ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 512], placement=rank:1/cuda:1),
ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 512], placement=rank:2/cuda:2),
ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 512], placement=rank:3/cuda:3),
ShardMetadata(shard_offsets=[8, 0], shard_sizes=[2, 512], placement=rank:4/cuda:4),
ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:5/cuda:5),
ShardMetadata(shard_offsets=[12, 0], shard_sizes=[0, 512], placement=rank:6/cuda:6)
],
size=torch.Size([10, 512]), tensor_properties=TensorProperties(dtype=torch.float16, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))
```
Note that rank 6's shard now has offset [12, 0], which exceeds the tensor's dim-0 size of 10 (the Torchrec metadata had [10, 0]). The deduced ChunkShardingSpec:
```
ChunkShardingSpec(dim=0, placements=[rank:0/cuda:0, rank:1/cuda:1, rank:2/cuda:2, rank:3/cuda:3, rank:4/cuda:4, rank:5/cuda:5, rank:6/cuda:6])
```
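For illustration, a standalone sketch of the chunk-offset arithmetic (hypothetical helper, not PyTorch's internal implementation): splitting dim 0 of size 10 across 7 placements uses a split size of ceil(10 / 7) = 2, so the trailing empty shard gets an offset past the end of the tensor.
```python
# Hypothetical sketch, not PyTorch's actual code: ChunkShardingSpec-style
# chunking uses ceil(dim_size / num_placements)-sized pieces, so trailing
# empty shards can be assigned offsets beyond the tensor's dim size.
def naive_chunk_offsets(dim_size: int, num_placements: int) -> list[int]:
    split_size = -(-dim_size // num_placements)  # ceil(dim_size / num_placements)
    return [split_size * idx for idx in range(num_placements)]

print(naive_chunk_offsets(10, 7))  # [0, 2, 4, 6, 8, 10, 12] -- 12 > dim size 10
```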

The fix is to clamp the generated shard offsets to the size of the sharding dimension.
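A minimal sketch of that clamping (assumed names, not the actual patch); with the clamp, both empty shards report offset 10, matching the Torchrec-prepared metadata above:
```python
# Hypothetical sketch of the fix, not the actual patch: clamp each shard's
# offset to the dim size so empty shards stay in bounds.
def clamped_chunk_offsets(dim_size: int, num_placements: int) -> list[int]:
    split_size = -(-dim_size // num_placements)  # ceil(dim_size / num_placements)
    return [min(split_size * idx, dim_size) for idx in range(num_placements)]

print(clamped_chunk_offsets(10, 7))  # [0, 2, 4, 6, 8, 10, 10]
```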

Differential Revision: [D54419513](https://our.internmc.facebook.com/intern/diff/D54419513)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121002
Approved by: https://github.com/wz337
2024-03-02 08:58:48 +00:00
| Name | Last commit | Date |
| --- | --- | --- |
| sharded_optim | | |
| sharded_tensor | [dist][sharded_tensor] Fix ChunkShardingSpec metadata offsets for empty shards (#121002) | 2024-03-02 08:58:48 +00:00 |
| sharding_plan | | |
| sharding_spec | Refactor some tests by using TEST_CUDA & TEST_MULTIGPU instead (#116083) | 2024-01-03 08:53:59 +00:00 |
| test_sharder.py | Add call to run_tests for a few tests (#115097) | 2023-12-07 08:27:40 +00:00 |