Commit Graph

20 Commits

Author SHA1 Message Date
Aaron Orenstein
00ffeca1b1 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-21 04:23:29 +00:00
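For context, a hedged illustration of what a PEP 585 update typically looks like (the function below is hypothetical, not taken from the PR): typing-module generics such as `Dict`, `List`, and `Tuple` are replaced with the built-in `dict`, `list`, and `tuple`, which are subscriptable on Python 3.9+ (or earlier with `from __future__ import annotations`).

```python
# Illustrative only; names below are hypothetical, not from the PR.
from __future__ import annotations  # lets built-in generics be used in annotations pre-3.9

# Before (typing-module aliases):
#   from typing import Dict, List, Tuple
#   def bucket_params(sizes: List[int]) -> Dict[str, Tuple[int, int]]: ...

# After (PEP 585 built-in generics):
def bucket_params(sizes: list[int]) -> dict[str, tuple[int, int]]:
    return {f"bucket{i}": (i, s) for i, s in enumerate(sizes)}
```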
PyTorch MergeBot
6374332d33 Revert "PEP585 update - torch/distributed (#145164)"
This reverts commit 6cb186e279.

Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
2025-01-20 16:46:46 +00:00
Aaron Orenstein
6cb186e279 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-20 00:19:01 +00:00
bobrenjc93
fbad833538 Migrate from Tuple -> tuple in test/distributed/_composable (#144254)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144254
Approved by: https://github.com/aorenste
2025-01-10 06:38:05 +00:00
Aaron Orenstein
45ef3309e3 [BE] typing for decorators (#144161)
Summary:
Untyped decorators strip annotations from the decorated items.

- _compile
- _inductor/fx_passes/post_grad
- _inductor/lowering
- _library/custom_ops
- _meta_registrations
- _ops
- _refs/nn/functional
- ao/quantization/quantizer/xnnpack_quantizer_utils
- distributed/_composable/contract
- fx/experimental/graph_gradual_typechecker
- fx/experimental/migrate_gradual_types/constraint_generator
- optim/optimizer
- signal/windows/windows
- testing/_internal/common_device_type
- torch/_inductor/decomposition
- utils/flop_counter

Test Plan: unit tests

Differential Revision: D62302684

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144161
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-01-04 16:40:09 +00:00
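As a hedged sketch of the underlying problem (not code from the PR): a decorator annotated with a bare `Callable` erases the wrapped signature, whereas a `ParamSpec`-typed decorator preserves it for mypy and for callers. Assumes Python 3.10+ for `typing.ParamSpec`.

```python
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def logged(fn: Callable[P, R]) -> Callable[P, R]:
    # Because the decorator is typed with ParamSpec, mypy still sees the
    # original (x: float, factor: float = ...) -> float signature on the result.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)
    return wrapper

@logged
def scale(x: float, factor: float = 2.0) -> float:
    return x * factor

scale(3.0)      # type checks
# scale("3")    # flagged by mypy; with an untyped decorator it would not be
```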
Oguz Ulgen
72d2dba992 Add None return type to init (#132335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
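A small hedged example of the convention this PR applies (the class is hypothetical): mypy only treats a no-argument `__init__` as typed if it has an explicit return annotation, so `-> None` is required under `disallow_untyped_defs`.

```python
class Counter:
    # mypy treats a bare "def __init__(self):" as untyped and skips checking
    # its body; the explicit "-> None" makes it a checked, typed function.
    def __init__(self) -> None:
        self.value: int = 0
```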
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function, so even if the underlying function is fully typed, callers to it don't get any benefit from its type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
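A hedged sketch of what "enable the error and override in all the offending files" might look like in a mypy config (the section names below are illustrative, not copied from PyTorch's actual mypy.ini):

```ini
# mypy.ini (illustrative)
[mypy]
disallow_untyped_decorators = True

# Per-module override for files whose decorators are not yet typed.
[mypy-torch.distributed._composable.contract]
disallow_untyped_decorators = False
```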
Sanket Jayant Purandare
d67923b955 Adding kwargs to composable AC API to enable full capabilities (#128516)
Summary:
Firstly, this does not change any existing behaviour, since all the
default values for kwargs were hardcoded into the ``_checkpoint_without_reentrant_generator`` call.

Secondly, this is needed to unlock the full potential of composable
checkpointing, making it equivalent to ``torch.utils.checkpoint.checkpoint(use_reentrant=False)``.

Finally, an added benefit is that composable checkpointing can now be used under ``FakeTensorMode`` by
passing ``preserve_rng_state=False``.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128516
Approved by: https://github.com/awgu
2024-06-15 00:23:48 +00:00
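A minimal hedged sketch of what the unlocked kwargs enable, assuming the composable `checkpoint` API forwards keyword arguments such as `preserve_rng_state` through to `_checkpoint_without_reentrant_generator`:

```python
import torch
from torch import nn
from torch.distributed._composable import checkpoint

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# Apply composable activation checkpointing to a submodule, overriding a kwarg
# that previously had to stay at its hardcoded default.
checkpoint(model[0], preserve_rng_state=False)

out = model(torch.randn(2, 8, requires_grad=True))
out.sum().backward()
```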
Aaron Orenstein
62bcdc0ac9 Flip default value for mypy disallow_untyped_defs [4/11] (#127841)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127841
Approved by: https://github.com/oulgen
2024-06-08 18:36:48 +00:00
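A hedged illustration of the migration pattern (the file name and functions are hypothetical): once the global default is flipped, files that are not yet fully annotated opt out with a module-level inline directive rather than a config entry.

```python
# torch/distributed/_composable/_example.py (hypothetical file)
# mypy: allow-untyped-defs
#
# The inline directive above relaxes disallow_untyped_defs for this module only,
# so legacy unannotated functions keep type-checking while new code is written
# with full annotations.

def legacy_helper(x):           # tolerated in this module
    return x * 2

def new_helper(x: int) -> int:  # fully annotated going forward
    return x * 2
```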
Rohan Varma
5d70fe0165 [Composable] Use non-reentrant generator, remove reentrant (#105176)
Removes reentrant support from the composable checkpoint, as
non-reentrant is the recommended approach and the one we should use when rolling
out the composable checkpoint API.

Also removes the standalone non-reentrant implementation and instead uses
the generator from the diff below to reuse the original implementation.

Differential Revision: [D47451375](https://our.internmc.facebook.com/intern/diff/D47451375/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105176
Approved by: https://github.com/awgu, https://github.com/fegin
2023-07-26 07:03:03 +00:00
Andrew Gu
2eea3cb19d Fix composable checkpoint(use_reentrant=True) with multi args (#103590)
The `_ModuleHookCheckpointFunction.backward()` should take in `*output_grads` instead of `output_grads`. Otherwise, we may see an error like:
```
TypeError: backward() takes 2 positional arguments but 5 were given
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
2023-06-14 21:53:30 +00:00
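A hedged toy example (not the PR's code) of why the star matters: when the checkpointed region produces multiple outputs, autograd calls `backward` with one gradient per output, so the signature must accept them variadically.

```python
import torch

class ScaleTwo(torch.autograd.Function):
    """Toy Function with two outputs; autograd passes one grad per output to backward."""

    @staticmethod
    def forward(ctx, x, y):
        return 2 * x, 3 * y

    @staticmethod
    def backward(ctx, *grad_outputs):  # not backward(ctx, grad_outputs)
        grad_x, grad_y = grad_outputs
        return 2 * grad_x, 3 * grad_y

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
out1, out2 = ScaleTwo.apply(a, b)
(out1.sum() + out2.sum()).backward()
```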
Aaron Gokaslan
3e2ea32dab [BE]: Enable ruff rule TRY302 and apply fixes (#101874)
Removes useless try statements and unreachable code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101874
Approved by: https://github.com/malfet
2023-05-19 17:30:52 +00:00
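For reference, a hedged example of the pattern TRY302 flags (functions are illustrative): a `try` block whose only handler immediately re-raises adds nothing and can be removed.

```python
# Flagged by ruff TRY302: the except clause only re-raises.
def read_config(path: str) -> str:
    try:
        with open(path) as f:
            return f.read()
    except OSError:
        raise

# Equivalent after the fix: let the exception propagate naturally.
def read_config_fixed(path: str) -> str:
    with open(path) as f:
        return f.read()
```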
Colin Taylor
e5496ebcac [torch] [composable] [analytics] add analytics logging to PT-D composable APIs (#95016)
Summary: as title

Test Plan: N/A

Differential Revision: D43376274

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95016
Approved by: https://github.com/awgu, https://github.com/rohan-varma, https://github.com/fegin
2023-02-17 02:49:16 +00:00
Matthew Hoffman
a26e5e21b5 Improve type hints for Module forward hooks (#92061)
Fixes #91654.

Currently, the `hook` parameters of `nn.Module.register_forward_pre_hook` and `nn.Module.register_forward_hook` are typed as `Callable[..., None]`, which 1) does not enable the validation of the signature of `hook` and 2) incorrectly restricts the return type of `hook`, which the docstrings of these methods themselves state can be non-`None`.

Typing the first parameter of `hook` as `TypeVar("T", bound="Module")` allows binding a `Callable` whose first parameter is a subclass of `Module`.

---

Here are some examples of:
1. forward hooks and pre-hook hooks being accepted by mypy according to the new type hints
2. mypy throwing errors due to incorrect `hook` signatures
3. false negatives of pre-hooks being accepted as forward hooks
4. false negatives of hooks with kwargs being accepted irrespective of the value provided for `with_kwargs`

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn

def forward_pre_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> None:
    ...

def forward_pre_hook_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
    ...

def forward_pre_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> None:
    ...

def forward_pre_hook_with_kwargs_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    ...

def forward_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

def forward_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> None:
    ...

def forward_hook_with_kwargs_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> torch.Tensor:
    ...

model = nn.Module()

# OK
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_pre_hook(forward_pre_hook_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=True)

model.register_forward_hook(forward_hook)
model.register_forward_hook(forward_hook_return_output)
model.register_forward_hook(forward_hook_with_kwargs, with_kwargs=True)
model.register_forward_hook(forward_hook_with_kwargs_return_output, with_kwargs=True)

# mypy(error): [arg-type]
model.register_forward_pre_hook(forward_hook)
model.register_forward_pre_hook(forward_hook_return_output)
model.register_forward_pre_hook(forward_hook_with_kwargs)
model.register_forward_pre_hook(forward_hook_with_kwargs_return_output)

model.register_forward_hook(forward_pre_hook)
model.register_forward_hook(forward_pre_hook_return_input)

# false negatives
model.register_forward_hook(forward_pre_hook_with_kwargs)
model.register_forward_hook(forward_pre_hook_with_kwargs_return_input)

model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=False)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=False)
...
```

---

Though it is not functional as of mypy 0.991, the ideal typing of these methods would use [`typing.Literal`](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types):

```python
T = TypeVar("T", bound="Module")

class Module:

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[False] = ...,
    ) -> RemovableHandle:
        ...

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[True] = ...,
    ) -> RemovableHandle:
        ...

    def register_forward_hook(...):
        ...

```

which would:

1. validate the signature of `hook` according to the corresponding literal value provided for `with_kwargs` (and fix the false negative examples above)
2. implicitly define the [fallback `bool` signature](https://github.com/python/mypy/issues/6113#issuecomment-1266186192) e.g. to handle if a non-literal is provided for `with_kwargs`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92061
Approved by: https://github.com/albanD
2023-01-13 15:45:42 +00:00
joncrall
ad782ff7df Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. We just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
2022-12-29 05:32:42 +00:00
Chien-Chin Huang
d08e3d2304 [Composable API] Apply ufmt to _composable and the corresponding test folders (#91255)
This PR applies ufmt to format `_composable`-related code. This is a request from https://github.com/pytorch/pytorch/pull/91234 to separate the formatting changes into a new PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91255
Approved by: https://github.com/awgu
2022-12-23 16:08:27 +00:00
Shen Li
a0554261a1 Restore RNG states for composable reentrant activation checkpointing (#91265)
This allows ops like randperm to behave the same during re-computation.

Differential Revision: [D42196758](https://our.internmc.facebook.com/intern/diff/D42196758/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91265
Approved by: https://github.com/awgu
2022-12-22 03:15:55 +00:00
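A hedged illustration of the behavior being preserved, using the functional API (the composable reentrant path does the equivalent internally): when the RNG state is stashed and restored, an op like `randperm` produces the same permutation during recomputation, so the gradients line up with the original forward.

```python
import torch
from torch.utils.checkpoint import checkpoint

def shuffle_and_scale(x: torch.Tensor) -> torch.Tensor:
    # Uses RNG inside the checkpointed region; recomputation must see the same
    # RNG state, or the permutation (and hence the gradients) would differ.
    perm = torch.randperm(x.shape[0], device=x.device)
    return (x[perm] * 2.0).sum()

x = torch.randn(4, requires_grad=True)
out = checkpoint(shuffle_and_scale, x, use_reentrant=True, preserve_rng_state=True)
out.backward()
```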
Chien-Chin Huang
d52f121dba [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)
**Why this PR?**

In the composable APIs implementation, internal APIs sometimes have access only to a local module, not the application (FSDP, DDP) root module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs are designed to start from the root module of the model, so it is tricky for them to tell whether an arbitrary submodule is managed by DDP or FSDP.

It will be useful to have APIs like:
`_get_module_state(module)`: return the composable state if this module is managed by a composable API.
`_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP.

**What does this PR propose?**
1. Move `_State` out of the `_composable` module so that `FullyShardedDataParallel` can inherit from it.
2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just root module) to the state.
3. Create `_get_module_state(module)` to look up `_module_state_mapping`.
4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state then verifies if the state is `_FSDPState`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147
Approved by: https://github.com/awgu
2022-12-13 23:58:01 +00:00
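A minimal hedged sketch (not the actual PyTorch code) of the lookup scheme described in points 2-3 above: every submodule managed by a composable API is registered in a global mapping, so its state can be recovered from any submodule rather than only from the application root.

```python
from typing import Dict, Optional

import torch.nn as nn

class _State:
    """Base class shared by composable-API states (and, per this PR, FSDP's state)."""

# Maps every managed submodule (not just the root) to its state object.
_module_state_mapping: Dict[nn.Module, _State] = {}

def _get_module_state(module: nn.Module) -> Optional[_State]:
    # Returns the composable state if this module is managed by a composable API.
    return _module_state_mapping.get(module)
```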
Shen Li
7bd284495a Add non-reentrant checkpoint to composable APIs (#90015)
Differential Revision: [D41661027](https://our.internmc.facebook.com/intern/diff/D41661027)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90015
Approved by: https://github.com/zhaojuanmao
2022-12-01 23:05:55 +00:00
Shen Li
d9b6e41da9 Add composable activation checkpointing (#87664)
This is a composable activation checkpointing API. Unlike functional
activation checkpointing APIs, this one does not require changing
model source code. Unlike ``nn.Module`` wrapper activation checkpointing
APIs, this one does not modify model structure or fully-qualified names
either. Under the hood, it registers activation checkpointing logic as pre-
and post-forward hooks.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87664
Approved by: https://github.com/zhaojuanmao
2022-10-29 17:35:58 +00:00
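A short hedged usage sketch of the composable API described above: `checkpoint(module)` attaches the checkpointing logic via hooks, so no wrapper module appears in the hierarchy and parameter fully-qualified names are unchanged.

```python
import torch
from torch import nn
from torch.distributed._composable import checkpoint

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# In-place application via pre-/post-forward hooks: no wrapper is inserted,
# so FQNs like "0.weight" and "2.bias" stay exactly as they were.
checkpoint(model[0])

out = model(torch.randn(4, 16, requires_grad=True))
out.sum().backward()
print([name for name, _ in model.named_parameters()])
```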