pytorch/torch/distributed/tensor/parallel
fduwjj 25a2845d78 [TP] Enable embedding sharding in TP API (#111177)
We see use cases where embedding sharding is also needed in the TP API, so we enabled it there, since DTensor already supports colwise embedding sharding.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111177
Approved by: https://github.com/wanchaol
ghstack dependencies: #111160, #111166, #111176
2023-10-15 11:49:56 +00:00
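A minimal sketch of what this change enables: sharding an nn.Embedding column-wise through the TP API's parallelize_module with ColwiseParallel. The mesh size, model, and layer name below are illustrative assumptions, not taken from the PR, and the import paths reflect the current public API rather than the exact state of the tree at this commit.

```python
# Sketch only: shard an embedding table column-wise via the TP API.
# Assumes an already-initialized process group and 8 GPUs in a 1-D mesh.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh = init_device_mesh("cuda", (8,))  # hypothetical 1-D mesh over 8 ranks

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative sizes: 10k-token vocabulary, 512-dim embeddings.
        self.embed = nn.Embedding(10000, 512)

    def forward(self, tokens):
        return self.embed(tokens)

model = ToyModel().cuda()
# ColwiseParallel shards the embedding weight along the embedding dimension,
# so each rank holds a 512/8-wide slice of the table.
model = parallelize_module(model, mesh, {"embed": ColwiseParallel()})
```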
__init__.py [TP] Add prepareInput and output for input/output DTensor layout annotation in the parent module in TP API (#111166) 2023-10-14 15:37:52 +00:00
_data_parallel_utils.py [3/N][2D] Enable training with new 2D flow (#110034) 2023-09-26 09:14:15 +00:00
_utils.py [TP] Refactor Parallel Style to make it more usable (#111160) 2023-10-14 15:26:36 +00:00
_view_with_dim_change.py [TP][DTensor Perf] Some perf improvement to reduce DTensor CPU overhead (#106524) 2023-08-14 20:03:19 +00:00
api.py [TP] Enable embedding sharding in TP API (#111177) 2023-10-15 11:49:56 +00:00
ddp.py [2D][TP] Enable DDP TP integration with unit test (#106583) 2023-08-17 02:54:17 +00:00
fsdp.py [2D] Enable 2D FSDP+TP model.load_state_dict() (#110925) 2023-10-11 18:22:20 +00:00
input_reshard.py [Reland] Update mypy to 1.4.1 (#105227) 2023-07-15 20:30:20 +00:00
style.py [TP] Enable embedding sharding in TP API (#111177) 2023-10-15 11:49:56 +00:00