pytorch/test/distributed/tensor/parallel
Andrew Gu 85dc254364 [DTensor] Moved Transformer sharding to staticmethod (#121660)
To support FSDP + TP/SP unit tests, factor the canonical TP/SP sharding of `Transformer` out into a staticmethod that other tests can call.
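For context, a minimal sketch of what such a refactor can look like. The class layout, the `parallelize` name, and the per-layer plan below are illustrative assumptions, not the actual test code from this PR:

```python
# Sketch: expose a module's canonical TP sharding as a staticmethod so
# other unit tests (e.g. FSDP + TP/SP) can reuse the same plan instead of
# duplicating it. Names and the layer plan here are hypothetical.
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class Transformer(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 2) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
            for _ in range(n_layers)
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    @staticmethod
    def parallelize(module: "Transformer", tp_mesh: DeviceMesh) -> "Transformer":
        # Apply the canonical TP plan to each layer's MLP projections.
        # Callers need an initialized process group to construct `tp_mesh`.
        for layer in module.layers:
            parallelize_module(
                layer,
                tp_mesh,
                {
                    "linear1": ColwiseParallel(),
                    "linear2": RowwiseParallel(),
                },
            )
        return module
```

Keeping the plan behind a staticmethod means the FSDP + TP tests and the TP-only tests exercise one canonical sharding rather than each maintaining its own copy.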

Test Plan:
```
pytest test/distributed/tensor/parallel/test_tp_examples.py -k test_transformer_training
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121660
Approved by: https://github.com/wanchaol, https://github.com/yifuwang
ghstack dependencies: #121360, #121357
2024-03-12 15:08:57 +00:00
| File | Latest commit | Date |
| --- | --- | --- |
| `__init__.py` | | |
| `test_ddp_2d_parallel.py` | [tp] remove deprecated tp_mesh_dim arg (#121432) | 2024-03-08 17:46:44 +00:00 |
| `test_fsdp_2d_parallel.py` | [DeviceMesh] Rename get_dim_groups to get_group (#114708) | 2023-11-30 23:40:14 +00:00 |
| `test_parallelize_api.py` | [tp] remove deprecated tp_mesh_dim arg (#121432) | 2024-03-08 17:46:44 +00:00 |
| `test_tp_examples.py` | [DTensor] Moved Transformer sharding to staticmethod (#121660) | 2024-03-12 15:08:57 +00:00 |
| `test_tp_random_state.py` | deprecate PairwiseParallel from test (#114314) | 2023-11-30 02:19:30 +00:00 |
| `test_tp_style.py` | [TP] Introduce Sequence Parallel Style for Laynorm/RMSNorm/Dropout (#121295) | 2024-03-07 02:04:59 +00:00 |