Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 00:21:07 +01:00
This class is needed because `parallelize_module` takes a dict, which doesn't allow `PrepareModuleInput` and `PrepareModuleOutput` to be applied to the same module at the same time. The `PrepareModuleInputOutput` in this PR initializes two variables, `prepare_module_input` and `prepare_module_output`, and uses them to process the module, its inputs, and its outputs. I had another implementation that put all the code in `PrepareModuleInputOutput` and had `PrepareModuleInput` and `PrepareModuleOutput` inherit from the monolithic `PrepareModuleInputOutput`. But that approach is 1. less clean, and 2. a conceptual abuse of inheritance, because `PrepareModuleInput` shouldn't be able to access class methods of `PrepareModuleOutput`, and vice versa.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150372
Approved by: https://github.com/wanchaol
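The composition design described above can be sketched in plain Python. This is a minimal toy model under one assumption: a plan is a dict that maps a module name to a style object with an `apply` method. `Module`, `PrepareInput`, `PrepareOutput`, `PrepareInputOutput`, and `parallelize` are hypothetical stand-ins, not the real `torch.distributed.tensor.parallel` classes. No tensors or device meshes are involved.

```python
# Toy model of why a dict-keyed plan needs a combined input/output style.
# All names below are hypothetical stand-ins, NOT the real PyTorch classes.
from typing import Any, Callable, Dict, List

Transform = Callable[[Any], Any]


class Module:
    """Toy module: wraps a function and runs pre/post hooks around it."""

    def __init__(self, fn: Transform) -> None:
        self.fn = fn
        self.pre_hooks: List[Transform] = []
        self.post_hooks: List[Transform] = []

    def __call__(self, x: Any) -> Any:
        for hook in self.pre_hooks:  # prepare inputs
            x = hook(x)
        out = self.fn(x)
        for hook in self.post_hooks:  # prepare outputs
            out = hook(out)
        return out


class PrepareInput:
    """Stand-in for an input-preparation style: only touches inputs."""

    def __init__(self, transform: Transform) -> None:
        self.transform = transform

    def apply(self, module: Module) -> Module:
        module.pre_hooks.append(self.transform)
        return module


class PrepareOutput:
    """Stand-in for an output-preparation style: only touches outputs."""

    def __init__(self, transform: Transform) -> None:
        self.transform = transform

    def apply(self, module: Module) -> Module:
        module.post_hooks.append(self.transform)
        return module


class PrepareInputOutput:
    """Composes one input style and one output style (no inheritance)."""

    def __init__(self, input_transform: Transform, output_transform: Transform) -> None:
        self.prepare_input = PrepareInput(input_transform)
        self.prepare_output = PrepareOutput(output_transform)

    def apply(self, module: Module) -> Module:
        # Delegate to each part; neither part can see the other's methods.
        module = self.prepare_input.apply(module)
        return self.prepare_output.apply(module)


def parallelize(modules: Dict[str, Module], plan: Dict[str, Any]) -> None:
    """Toy analogue of a dict-keyed plan: exactly one style per module name."""
    for name, style in plan.items():
        modules[name] = style.apply(modules[name])


if __name__ == "__main__":
    mods = {"layer": Module(lambda x: x * 10)}
    parallelize(mods, {"layer": PrepareInputOutput(lambda x: x + 1, lambda y: y - 3)})
    print(mods["layer"](2))  # (2 + 1) * 10 - 3 = 27
```

A plan such as `{"layer": PrepareInput(...), "layer": PrepareOutput(...)}` silently keeps only the last entry, because a dict holds one value per key. A single composite style is therefore needed. Holding the two parts as attributes, rather than inheriting from them, keeps each part unaware of the other.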
| Name |
|---|
| __init__.py |
| _data_parallel_utils.py |
| _utils.py |
| api.py |
| ddp.py |
| fsdp.py |
| input_reshard.py |
| loss.py |
| style.py |