pytorch/torch/distributed/tensor/parallel
Tianyu Liu d2ad9aa2f2 [dtensor][tp] add a ParallelStyle PrepareModuleInputOutput (#150372)
This class is needed because `parallelize_module` takes a dict keyed by module path, which doesn't allow `PrepareModuleInput` and `PrepareModuleOutput` to be applied to the same module at the same time.
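For illustration, here is a minimal usage sketch of the combined style. The keyword arguments are assumed to mirror the existing `PrepareModuleInput` / `PrepareModuleOutput` arguments; treat the exact signature as an assumption and refer to the class docstring for the final API.

```python
# Hedged sketch: assumes an 8-GPU mesh and that PrepareModuleInputOutput's
# keyword arguments mirror PrepareModuleInput / PrepareModuleOutput.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    PrepareModuleInputOutput,
    parallelize_module,
)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Linear(16, 16)

    def forward(self, x):
        return self.block(x)


tp_mesh = init_device_mesh("cuda", (8,))
model = Model()

# A parallelize_plan dict holds only one ParallelStyle per module path, so
# previously "block" could get PrepareModuleInput OR PrepareModuleOutput, not both.
parallelize_module(
    model,
    tp_mesh,
    {
        "block": PrepareModuleInputOutput(
            input_layouts=(Shard(0),),
            desired_input_layouts=(Replicate(),),
            output_layouts=(Replicate(),),
            desired_output_layouts=(Shard(0),),
        ),
    },
)
```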

The `PrepareModuleInputOutput` added in this PR initializes two member variables, `prepare_module_input` and `prepare_module_output`, and uses them to prepare the module's inputs and outputs respectively.
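A rough sketch of that composition is below. It is illustrative only, not the verbatim implementation; it assumes the `ParallelStyle` base class and its `_apply(module, device_mesh)` hook from `torch.distributed.tensor.parallel.style`, and omits the kwarg-layout and `use_local_*` options.

```python
# Rough sketch of the composition described above (not the real implementation).
from torch.distributed.tensor.parallel.style import (
    ParallelStyle,
    PrepareModuleInput,
    PrepareModuleOutput,
)


class PrepareModuleInputOutputSketch(ParallelStyle):
    def __init__(
        self,
        *,
        input_layouts=None,
        desired_input_layouts=None,
        output_layouts=None,
        desired_output_layouts=None,
    ):
        super().__init__()
        # One helper style per side of the module boundary.
        self.prepare_module_input = PrepareModuleInput(
            input_layouts=input_layouts,
            desired_input_layouts=desired_input_layouts,
        )
        self.prepare_module_output = PrepareModuleOutput(
            output_layouts=output_layouts,
            desired_output_layouts=desired_output_layouts,
        )

    def _apply(self, module, device_mesh):
        # Register the input hook and the output hook on the same module.
        self.prepare_module_input._apply(module, device_mesh)
        self.prepare_module_output._apply(module, device_mesh)
        return module
```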

I had another implementation that put all the code in `PrepareModuleInputOutput` and let `PrepareModuleInput` and `PrepareModuleOutput` inherit from the monolithic `PrepareModuleInputOutput`. But that approach is
1. less clean
2. a conceptual abuse of inheritance, because `PrepareModuleInput` shouldn't be able to access the class methods of `PrepareModuleOutput`, and vice versa

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150372
Approved by: https://github.com/wanchaol
2025-04-01 19:15:43 +00:00
__init__.py [dtensor][tp] add a ParallelStyle PrepareModuleInputOutput (#150372) 2025-04-01 19:15:43 +00:00
_data_parallel_utils.py Migrate from Tuple -> tuple in torch/distributed (#144258) 2025-01-10 08:34:54 +00:00
_utils.py Migrate from Tuple -> tuple in torch/distributed (#144258) 2025-01-10 08:34:54 +00:00
api.py PEP585 update - torch/distributed/tensor (#145141) 2025-01-18 20:01:59 +00:00
ddp.py PEP585 update - torch/distributed/tensor (#145141) 2025-01-18 20:01:59 +00:00
fsdp.py [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547) 2025-02-28 07:35:56 +00:00
input_reshard.py [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547) 2025-02-28 07:35:56 +00:00
loss.py [BE][Ez]: Remove extra copy in dtensor parallel loss (#148096) 2025-02-28 05:42:32 +00:00
style.py [dtensor][tp] add a ParallelStyle PrepareModuleInputOutput (#150372) 2025-04-01 19:15:43 +00:00