Commit Graph

6 Commits

joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Chien-Chin Huang
d08e3d2304 [Composable API] Apply ufmt to _composable and the corresponding test folders (#91255)
This PR applies ufmt to format `_composable`-related code. This is a request from https://github.com/pytorch/pytorch/pull/91234 to separate the formatting changes into a new PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91255
Approved by: https://github.com/awgu
2022-12-23 16:08:27 +00:00
Shen Li
a0554261a1 Restore RNG states for composable reentrant activation checkpointing (#91265)
This allows ops like `randperm` to behave the same during recomputation.
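
A minimal sketch of the underlying idea (not this PR's implementation): if the RNG state captured at the first forward pass is restored before recomputation, random ops such as `randperm` replay identically.

```python
import torch

def run_segment(x):
    # Random op inside a checkpointed segment.
    idx = torch.randperm(x.shape[0])
    return x[idx]

x = torch.arange(6.0)

cpu_state = torch.get_rng_state()      # captured at the first forward
first = run_segment(x)

torch.set_rng_state(cpu_state)         # restored before recomputation
recomputed = run_segment(x)

assert torch.equal(first, recomputed)  # same permutation both times
```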

Differential Revision: [D42196758](https://our.internmc.facebook.com/intern/diff/D42196758/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91265
Approved by: https://github.com/awgu
2022-12-22 03:15:55 +00:00
Chien-Chin Huang
d52f121dba [Composable API] Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

For the composable API implementation, internal APIs sometimes do not have access to the application (FSDP, DDP) root module, only to a local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs are designed to start from the root module of the model, so it is tricky for them to tell whether an arbitrary submodule is managed by DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by a composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. Add a global `_module_state_mapping: Dict[nn.Module, _State]` that maps every managed submodule (not just the root module) to its state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)`, which uses `_get_module_state(module)` to get the state and then verifies that the state is an `_FSDPState` (see the sketch below).
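
A minimal sketch of the lookup helpers described above; the `_State` and `_FSDPState` definitions here are illustrative stand-ins, not the actual FSDP classes.

```python
from typing import Dict, Optional

import torch.nn as nn


class _State:
    """Base state object shared by composable and wrapper FSDP (stand-in)."""


class _FSDPState(_State):
    """FSDP-specific state (stand-in)."""


# Global mapping from every managed submodule (not just the root) to its state.
_module_state_mapping: Dict[nn.Module, _State] = {}


def _get_module_state(module: nn.Module) -> Optional[_State]:
    # Return the composable state if this module is managed by a composable API.
    return _module_state_mapping.get(module)


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    # Reuse _get_module_state and verify the state is FSDP-specific.
    state = _get_module_state(module)
    if isinstance(state, _FSDPState):
        return state
    return None
```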

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00
Shen Li
7bd284495a Add non-reentrant checkpoint to composable APIs (#90015)
Differential Revision: [D41661027](https://our.internmc.facebook.com/intern/diff/D41661027)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90015
Approved by: https://github.com/zhaojuanmao
2022-12-01 23:05:55 +00:00
Shen Li
d9b6e41da9 Add composable activation checkpointing (#87664)
This is a composable activation checkpointing API. Unlike functional activation checkpointing APIs, this one does not require changing model source code. Unlike `nn.Module` wrapper activation checkpointing APIs, this one does not modify model structure or fully-qualified names either. Under the hood, it registers activation checkpointing logic as pre- and post-forward hooks.
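
A hypothetical usage sketch of the composable style described above; the `torch.distributed._composable.checkpoint` import path and call signature are assumptions based on this commit's description, not a verified API.

```python
import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint  # assumed location

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# Apply activation checkpointing to a submodule in place: no wrapper module is
# inserted, so the model structure and parameter fully-qualified names are unchanged.
checkpoint(model[0])

x = torch.randn(2, 8, requires_grad=True)
out = model(x)
out.sum().backward()  # activations of model[0] are recomputed during backward
```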
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87664
Approved by: https://github.com/zhaojuanmao
2022-10-29 17:35:58 +00:00