pytorch/torch/distributed/tensor/parallel
Tianyu Liu 7b0e10f0e5 fix _MaskPartial when multiple embeddings coexist (#131264)
Previously, using `_MaskPartial` when multiple embeddings coexist had the following issues:
1. Suppose an `nn.Embedding` has shape `[vocab_size, emb_size]`. When there is more than one embedding sharing the same `vocab_size` but with different `emb_size`s, they would not share an `OpStrategy`, since each, when involved in computation, would have a different `OpSchema`; however, there would be a cache hit for redistribute (specifically `_gen_transform_infos` in `torch/distributed/_tensor/_redistribute.py` when doing `Replicate` -> `_MaskPartial`), because `_MaskPartial` only records `vocab_size` as `logical_dim_size` and does not carry `emb_size` as an attribute. This cache hit is undesirable and causes trouble when doing all-reduce/reduce-scatter on the new `_MaskPartial` in a separate `OpStrategy`. The error was reported in #130725. In this PR, we introduce `offset_shape`, which records the embedding's full shape, to avoid cache hits across embeddings of different shapes.
2. The second issue arises when we have two `nn.Embedding`s `emb1` and `emb2` with the same shape. There will be a cache hit not only in `_gen_transform_infos` but also in `OpStrategy` generation. Previously, if we sequentially did `Replicate` -> `_MaskPartial` for both `emb1` and `emb2`, and then performed the reduction on `emb1`'s `_MaskPartial`, it would destroy the `MaskBuffer` and `emb2` would hit an error. This PR adds a `refcount` to the `MaskBuffer` so that it can be properly shared by multiple `nn.Embedding`s (see the illustrative sketch below).
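
For context, here is a minimal sketch of the two ideas, not the actual `_MaskPartial`/`MaskBuffer` code in `torch/distributed/tensor`; the class and method names below are hypothetical and simplified.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class _MaskBufferSketch:
    """Refcounted mask buffer (issue 2): shared by several embeddings."""

    data: Optional[torch.Tensor] = None
    refcount: int = 0

    def materialize_mask(self, mask: torch.Tensor) -> None:
        # The first user stores the mask; later users just bump the refcount
        # and reuse the same buffer.
        if self.refcount == 0:
            self.data = mask
        self.refcount += 1

    def release_mask(self) -> None:
        # Free the buffer only after the last user's reduction has finished,
        # so emb1's reduction no longer invalidates the mask emb2 still needs.
        assert self.refcount > 0 and self.data is not None
        self.refcount -= 1
        if self.refcount == 0:
            self.data = None


@dataclass(frozen=True)
class _MaskPartialSketch:
    """Placement keyed by the embedding's full shape (issue 1).

    Including offset_shape in the hashable placement makes two embeddings that
    share vocab_size but differ in emb_size hash differently, so the
    Replicate -> _MaskPartial transform-info cache no longer aliases them.
    """

    logical_dim_size: int  # vocab_size
    offset_shape: Optional[Tuple[int, ...]] = None  # full (vocab_size, emb_size)


# Usage sketch: two embeddings share one buffer; freeing happens only once
# both have finished their reductions.
buf = _MaskBufferSketch()
buf.materialize_mask(torch.zeros(8, dtype=torch.bool))  # emb1
buf.materialize_mask(torch.zeros(8, dtype=torch.bool))  # emb2 reuses it
buf.release_mask()  # emb1 done; buffer stays alive for emb2
buf.release_mask()  # emb2 done; buffer freed
```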

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131264
Approved by: https://github.com/wanchaol
2024-07-29 00:40:58 +00:00
__init__.py [BE][Easy] enable UFMT for torch/distributed/{tensor,_tensor}/ (#128868) 2024-06-18 21:49:02 +00:00
_data_parallel_utils.py [reland] pass shape/stride during tensor unflatten (#117340) 2024-01-13 19:33:47 +00:00
_utils.py [BE][Easy] enable UFMT for torch/distributed/{tensor,_tensor}/ (#128868) 2024-06-18 21:49:02 +00:00
api.py [BE][Easy] enable UFMT for torch/distributed/{tensor,_tensor}/ (#128868) 2024-06-18 21:49:02 +00:00
ddp.py [BE][Easy] enable UFMT for torch/distributed/{tensor,_tensor}/ (#128868) 2024-06-18 21:49:02 +00:00
fsdp.py [BE][Easy] enable UFMT for torch/distributed/{tensor,_tensor}/ (#128868) 2024-06-18 21:49:02 +00:00
input_reshard.py [BE][Easy] enable UFMT for torch/distributed/{tensor,_tensor}/ (#128868) 2024-06-18 21:49:02 +00:00
loss.py fix _MaskPartial when multiple embeddings coexist (#131264) 2024-07-29 00:40:58 +00:00
style.py [tp] improve SequenceParallel and its documentation (#131346) 2024-07-23 03:57:01 +00:00