mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
ChunkShardingSpec generated metadata where offsets exceed the tensor size.

Example:

Torchrec prepared ShardedTensorMetadata:
```
ShardedTensorMetadata(shards_metadata=[
ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 512], placement=rank:0/cuda:0),
ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 512], placement=rank:1/cuda:1),
ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 512], placement=rank:2/cuda:2),
ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 512], placement=rank:3/cuda:3),
ShardMetadata(shard_offsets=[8, 0], shard_sizes=[2, 512], placement=rank:4/cuda:4),
ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:5/cuda:5),
ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:6/cuda:6)
],
size=torch.Size([10, 512]),
```

Calling ShardedTensor._init_from_local_shards_and_global_metadata(), the ShardedTensor ShardingSpec builds metadata:
```
ShardedTensorMetadata(shards_metadata=[
ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 512], placement=rank:0/cuda:0),
ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 512], placement=rank:1/cuda:1),
ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 512], placement=rank:2/cuda:2),
ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 512], placement=rank:3/cuda:3),
ShardMetadata(shard_offsets=[8, 0], shard_sizes=[2, 512], placement=rank:4/cuda:4),
ShardMetadata(shard_offsets=[10, 0], shard_sizes=[0, 512], placement=rank:5/cuda:5),
ShardMetadata(shard_offsets=[12, 0], shard_sizes=[0, 512], placement=rank:6/cuda:6)
],
size=torch.Size([10, 512]),
tensor_properties=TensorProperties(dtype=torch.float16, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))
```

The deduced ChunkShardingSpec:
```
ChunkShardingSpec(dim=0, placements=[rank:0/cuda:0, rank:1/cuda:1, rank:2/cuda:2, rank:3/cuda:3, rank:4/cuda:4, rank:5/cuda:5, rank:6/cuda:6])
```

The fix is to limit offsets by dim size.

Differential Revision: [D54419513](https://our.internmc.facebook.com/intern/diff/D54419513)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121002

Approved by: https://github.com/wz337
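The offset arithmetic behind the example can be reproduced with a small pure-Python sketch. This is not the PyTorch implementation: `chunk_shard_metadata` and its `clamp` flag are hypothetical names used only for illustration, and the sketch assumes a `torch.chunk`-style split where each rank gets `ceil(dim_size / num_chunks)` rows along the sharded dimension.

```python
import math


def chunk_shard_metadata(dim_size, num_chunks, clamp=True):
    """Return one (offset, size) pair per rank along the sharded dim.

    Hypothetical helper, not part of torch.distributed. With clamp=False
    it reproduces the out-of-range offsets described in the commit
    message; with clamp=True offsets are limited by the dim size.
    """
    # Assumed torch.chunk-style split size: ceil(dim_size / num_chunks).
    split_size = math.ceil(dim_size / num_chunks)
    shards = []
    for rank in range(num_chunks):
        start = rank * split_size
        # Trailing ranks may receive an empty shard; never a negative size.
        size = max(min(dim_size, start + split_size) - start, 0)
        # The fix: an offset may equal dim_size but must not exceed it.
        offset = min(start, dim_size) if clamp else start
        shards.append((offset, size))
    return shards


# A dim of size 10 split over 7 ranks, as in the commit message.
unclamped = chunk_shard_metadata(10, 7, clamp=False)
clamped = chunk_shard_metadata(10, 7, clamp=True)
print(unclamped[-1])  # last rank before the fix
print(clamped[-1])    # last rank after the fix
```

For the 10-row tensor split over 7 ranks, the unclamped variant places the last (empty) shard at offset 12, while the clamped variant places it at offset 10, matching the metadata Torchrec prepared.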
| Name |
|---|
| sharded_optim |
| sharded_tensor |
| sharding_plan |
| sharding_spec |
| test_sharder.py |