Commit Graph

21 Commits

Maggie Moss
f414aa8e0d Add pyrefly suppressions (3/n) (#164588)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions (see the sketch below)
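
A suppression here is an inline comment, roughly of this shape (a sketch with an assumed error-code name, not a line copied from this diff):

```py
def parse_port(raw: str) -> int:
    # pyrefly: ignore  # bad-return
    return raw  # pyrefly would otherwise flag str where int is expected
```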
before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d

after:

 0 errors (1,970 ignored)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
2025-10-03 22:03:03 +00:00
zhxchen17
ffd58293f7 [dynamo] Guard serialization for FUNCTORCH_STACK_MATCH (#152616)
Make Functorch interpreters serializable most of the time, so that we can save the guards on functorch states. (A sketch of these nesting patterns follows the test-case list.)

## Test Cases:

0. torch.compile() without functorch layers present. The guard should fail once any layer is pushed.
1. torch.compile() nested in vmap.
2. torch.compile() nested in grad.
3. torch.compile() nested in jvp + vmap
4. torch.compile() nested functionalize
5. torch.compile() nested in vmap + grad
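
A minimal sketch of these nesting patterns, using the public torch.func APIs (illustrative only; not the PR's test code):

```py
import torch

@torch.compile
def f(x):
    return x.sin()

x = torch.randn(4, 3)
# Case 1: compile nested in vmap. FUNCTORCH_STACK_MATCH records the
# functorch interpreter stack that was active at trace time.
torch.vmap(f)(x)
# Case 2: compile nested in grad. A different ambient stack should fail
# the serialized guard and trigger a fresh compilation.
torch.func.grad(lambda t: f(t).sum())(x)
```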

Differential Revision: [D74008787](https://our.internmc.facebook.com/intern/diff/D74008787/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152616
Approved by: https://github.com/zou3519
ghstack dependencies: #152615
2025-05-05 18:05:56 +00:00
Aaron Orenstein
78bff1e8c1 PEP585 update - torch/_functorch (#145139)
See #145101 for details.
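For readers unfamiliar with the change: a PEP 585 update replaces typing-module generics with the builtin ones (illustrative example, not taken from the diff):

```py
from typing import List, Dict  # before: typing generics

def index(xs: List[str]) -> Dict[str, int]:
    return {x: i for i, x in enumerate(xs)}

def index_pep585(xs: list[str]) -> dict[str, int]:  # after: builtins
    return {x: i for i, x in enumerate(xs)}
```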
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145139
Approved by: https://github.com/bobrenjc93
2025-01-19 07:06:10 +00:00
Xuehai Pan
e7eeee473c [BE][Easy][14/19] enforce style for empty lines in import segments in torch/_[a-c]*/ and torch/_[e-h]*/ and torch/_[j-z]*/ (#129765)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129765
Approved by: https://github.com/ezyang
2024-07-31 10:42:50 +00:00
Aaron Orenstein
ea614fb2b1 Flip default value for mypy disallow_untyped_defs [2/11] (#127839)
See #127836 for details.
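
Roughly what `disallow_untyped_defs = True` flags (an illustrative sketch, not code from the PR):

```py
def scale(x, factor):  # flagged: function is missing type annotations
    return x * factor

def scale_typed(x: float, factor: float) -> float:  # passes the check
    return x * factor
```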

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127839
Approved by: https://github.com/oulgen
2024-06-08 18:23:08 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.
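
Concretely, the fix looks like this (illustrative):

```py
class ParseError(Exception):
    pass

def check(ok: bool) -> None:
    if not ok:
        # before the RSE fix: raise ParseError()
        raise ParseError  # after: no-argument call parentheses removed
```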

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Xuehai Pan
73f0ecc1ac [BE] UFMT directory torch/_functorch (#123723)
Part of #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123723
Approved by: https://github.com/Skylion007
2024-04-12 08:04:51 +00:00
Guilherme Leobas
32f9453c2a [dynamo] Emit FUNCTORCH_STACK_MATCH guard in vmap(compile(f)) case (#122786)
Fixes: #122201

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122786
Approved by: https://github.com/zou3519
2024-04-05 15:04:16 +00:00
Edward Z. Yang
05bbcae5bb Refactor functorch meta conversion (#122202)
At a high level, the goal of this refactor was to make it so that `MetaConverter.__call__` has a straightforward code structure in three steps: (1) check if we support doing meta conversion, (2) describe the tensor into MetaTensorDesc, (3) call `meta_tensor` on MetaTensorDesc. However, this is not so easy to do, because there is a big pile of special cases for functional tensor inside `__call__`.

The primary complication is handling the ambient functionalization state: specifically, the functorch dynamic layer stack and the Python functionalization dispatch. The old code demands that meta tensor conversion happen with this state disabled. But I discovered that reconstructing functorch tensors demands that the functorch layers be active; in fact, a batched tensor holds a pointer to the internal functorch layer.

I had some discussion with Richard Zou about what code structure here makes sense. In particular, one of the goals of the refactor here is that I can inflate MetaTensorDesc from an entirely different process, which may not have all of the functorch layers activated at the time we do reconstruction. So it seems to me that we should make it explicit in MetaTensorDesc that there was some functorch layer active at the time the functorch tensor was serialized, so that we could potentially know we need to reconstruct these layers on the other side. This is NOT implemented yet, but there's some notes about how potentially it could proceed. But the important thing here is we SHOULD disable everything when we run `meta_tensor`, and internally be responsible for restoring the stack. Actually, the necessary infra bits in functorch don't exist to do this, so I added some simple implementations in pyfunctorch.py.

The rest is splitting up the manipulations on tensor (we do things like sync the real tensor before describing it; Describer is responsible for this now) and I also tried to simplify the not supported condition, based on my best understanding of what the old thicket of conditions was doing. You may notice that the internal meta_tensor handling of functional tensor is inconsistent with surrounding code: this is because I *exactly* replicated the old reconstruction behavior; a further refactor would be to rationalize this.
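
Paraphrasing the three-step structure described above as pseudocode (the class and method names follow this description, but the bodies and the helpers marked hypothetical are assumptions for exposition):

```py
class MetaConverter:
    def __call__(self, t):
        # Step 1: check whether we support meta conversion at all.
        if not self.is_supported(t):            # hypothetical predicate
            return NotImplemented
        # Step 2: describe the tensor (the Describer also syncs the real
        # tensor first) into a plain-data MetaTensorDesc.
        desc = self.describer.describe(t)       # hypothetical attribute
        # Step 3: rebuild from the description with ambient functorch and
        # functionalization state disabled; meta_tensor restores the
        # layers it needs internally.
        with disable_ambient_state():           # hypothetical helper
            return self.meta_tensor(desc)
```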

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122202
Approved by: https://github.com/zou3519
2024-03-25 20:47:21 +00:00
Guilherme Leobas
4eaa000acc Teach dynamo about torch.func.jvp (#119926)
List of changes:
- Replace JVP_NESTING with torch._C._functorch.maybe_current_level()
- Remove all of the increment-nesting functions from wrap_fx_proxy_cls
- fwAD.make_dual receives the dual_level as a keyword argument
- Add jvp_increment_nesting, set_fwd_grad_enabled, and dual_level context managers to dynamo (see the sketch below)
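
A small sketch of the composition this enables, using the public torch.func and torch.compile APIs (not taken from the PR's tests):

```py
import torch

def f(x):
    return x.sin()

@torch.compile(fullgraph=True)
def jvp_f(x, t):
    # dynamo can now trace through torch.func.jvp instead of graph-breaking
    return torch.func.jvp(f, (x,), (t,))

primal_out, tangent_out = jvp_f(torch.randn(3), torch.ones(3))
```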

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119926
Approved by: https://github.com/zou3519
2024-03-22 20:25:47 +00:00
PyTorch MergeBot
0696db8202 Revert "Teach dynamo about torch.func.jvp (#119926)"
This reverts commit 17489784b6.

Reverted https://github.com/pytorch/pytorch/pull/119926 on behalf of https://github.com/peterbell10 due to broken mac jobs on main ([comment](https://github.com/pytorch/pytorch/pull/119926#issuecomment-2010327997))
2024-03-20 18:34:43 +00:00
Guilherme Leobas
17489784b6 Teach dynamo about torch.func.jvp (#119926)
List of changes:
- Replace JVP_NESTING with torch._C._functorch.maybe_current_level()
- Remove all of the increment-nesting functions from wrap_fx_proxy_cls
- fwAD.make_dual receives the dual_level as a keyword argument
- Add jvp_increment_nesting, set_fwd_grad_enabled, and dual_level context managers to dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119926
Approved by: https://github.com/zou3519
2024-03-20 13:09:19 +00:00
PyTorch MergeBot
36e5c1dcab Revert "Teach dynamo about torch.func.jvp (#119926)"
This reverts commit edd04b7c16.

Reverted https://github.com/pytorch/pytorch/pull/119926 on behalf of https://github.com/jeanschmidt due to lots of breakages in pull jobs, checking if reverting this one will help ([comment](https://github.com/pytorch/pytorch/pull/119926#issuecomment-2007915919))
2024-03-19 18:59:46 +00:00
Guilherme Leobas
edd04b7c16 Teach dynamo about torch.func.jvp (#119926)
List of changes:
- Replace JVP_NESTING with torch._C._functorch.maybe_current_level()
- Remove all of the increment-nesting functions from wrap_fx_proxy_cls
- fwAD.make_dual receives the dual_level as a keyword argument
- Add jvp_increment_nesting, set_fwd_grad_enabled, and dual_level context managers to dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119926
Approved by: https://github.com/zou3519
2024-03-19 13:06:42 +00:00
Guilherme Leobas
3319dbcd23 Update vmap guard to avoid recompilations (#119061)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119061
Approved by: https://github.com/zou3519
2024-02-13 20:50:23 +00:00
Aaron Gokaslan
dfe484a3b3 [BE]: Bugfix functorch and some generic typing improvements (#101337)
Fixes some typing bugs found with newer versions of mypy

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101337
Approved by: https://github.com/ezyang
2023-05-14 14:20:56 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
76a3869fc6 Support functionalization on torch.cond (#89966)
This PR adds a functionalization path for torch.cond. As a first pass, we functionalize only very restrictive use cases, explicitly disallowing the following (see the sketch after this list):

- Output of either branch aliasing its input
- In-place mutation on inputs given to either branch
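
A use that stays inside these restrictions, written against today's torch.cond entry point (a sketch; the PR itself predates this spelling of the API):

```py
import torch

def true_fn(x):
    return x.sin()  # returns a fresh tensor; no aliasing of the input

def false_fn(x):
    return x.cos()  # no in-place mutation of the input

x = torch.randn(4)
out = torch.cond(x.sum() > 0, true_fn, false_fn, (x,))
```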

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89966
Approved by: https://github.com/zou3519
2022-12-22 22:01:47 +00:00
Richard Zou
ffa37c9fca Add VmapInterpreter.randomness (in pyfunctorch) and provide it in the info object (#90789)
This PR:
- adds VmapInterpreter.randomness, which returns the randomness option the
user provided in vmap(..., randomness=...)
- adds randomness to the info object passed to the vmap staticmethod of
autograd.Function, so that the user can handle random operations on their
own terms: if randomness="error" and the autograd.Function has random
operations, it is the user's responsibility to raise an error (a sketch
follows).
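
A sketch of an autograd.Function that consults the new field (info.randomness comes from this PR; everything else in the snippet is illustrative):

```py
import torch

class NoisyScale(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x * torch.rand(())  # a random operation

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def vmap(info, in_dims, x):
        if info.randomness == "error":
            # The user, not functorch, raises for random ops here.
            raise RuntimeError("NoisyScale is random under vmap")
        (x_bdim,) = in_dims
        # Share one random scale across the batch ("same" semantics).
        return NoisyScale.apply(x), x_bdim

# e.g. torch.vmap(NoisyScale.apply, randomness="same")(torch.randn(5))
```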

Test Plan:
- updated unittest
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90789
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-17 00:43:43 +00:00
Richard Zou
4809e838c1 functorch.jvp support for autograd.Function (#90077)
This PR adds functorch.jvp support for autograd.Function. It does so by
adding a jvp rule for custom_function_call.

For a regular PyTorch operation (like at::sin), the VariableType kernel:
- re-dispatches to at::sin
- calls the jvp rule for at::sin

The jvp rule for custom_function_call does the analogous thing. It
constructs a new autograd.Function (because the above logic already
exists). Inside the forward, it re-dispatches to custom_function_call.
In the jvp rule, it calls the user-provided jvp rule.

Since this logic is really close to that of custom_function_call_grad, I
just put them together.
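
The user-facing surface this enables looks roughly like the following (a sketch written against the current autograd.Function forward-mode API, not code from the PR):

```py
import torch

class Sin(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x.sin()

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_forward(x)  # saved for the jvp rule

    @staticmethod
    def jvp(ctx, x_tangent):
        (x,) = ctx.saved_tensors
        return x_tangent * x.cos()

x, t = torch.randn(3), torch.ones(3)
out, out_tangent = torch.func.jvp(Sin.apply, (x,), (t,))
```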

Test Plan:
- added jvp rules to the autograd.Function in autograd_function_db
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90077
Approved by: https://github.com/albanD, https://github.com/soulitzer
2022-12-14 16:20:53 +00:00
Richard Zou
3049d99027 autograd.Function supports vmap staticmethod (#90037)
This PR adds a `vmap` staticmethod to autograd.Function and a
corresponding vmap kernel for custom_function_call. These two items mean
that autograd.Function with a vmap staticmethod can be used with vmap.

```py
import torch


def to_numpy(t):
    # Helper assumed by the example: convert a tensor to a numpy array.
    return t.detach().cpu().numpy()


class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Note: signature follows the current setup_context API.
        x, y = inputs
        ctx.save_for_backward(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if isinstance(x, torch.Tensor) and x.requires_grad:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if isinstance(y, torch.Tensor) and y.requires_grad:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        # A batch dim of 0 is falsy, so compare against None explicitly.
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0
```

API Spec
- the staticmethod takes two arguments (info, in_dims) as well as the
unexpanded inputs (x, y).
- If we think about it as `vmap(info, in_dims, *args)`, `in_dims` is a
pytree with the same tree structure as args. It has None if the arg is
not being vmapped over and an integer vmapped dimension index if it is.
- `info` is an object with metadata about the vmap. It currently has one
field, `info.batch_size`. In the future we can extend this by adding
things like the randomness information.
- If there is a single vmap going on, (x, y) are NOT BatchedTensors;
they've already been unpacked.
- We expect the user to return a `(outputs, out_dims)` tuple. `out_dims`
must "broadcast" to the same pytree structure as `outputs`.

Semantics
- vmap(NumpyMul.apply)(x, y) will apply the vmap staticmethod if there is
one and will never actually run NumpyMul.forward.
- For the autograd.Function to support nested vmap (e.g.,
`vmap(vmap(NumpyMul.apply))(x, y)`), the vmap staticmethod must call
into operations that vmap understands (i.e., PyTorch operators or other
autograd.Functions).

At a high level, this PR:
- adds a vmap rule for custom_function_call

Testing
- Added some tests for in_dims and info
- Added vmap staticmethod to most of the autograd.Function in
autograd_function_db and sent them through functorch's vmap-related
OpInfo tests

Future
- Better error messages if the user gets the return contract wrong. I
didn't include them in this PR because it might involve a refactor of
some of the existing code in functorch/_src/vmap.py that will add
~200LOC to the PR, but LMK if you'd prefer it here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90037
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-13 14:14:02 +00:00
Richard Zou
3bc327993f PyDispatcher integration with functorch (#88785)
This PR teaches PyDispatcher and PyOperator about functorch transforms.
It is important that PyDispatcher/PyOperator dispatch with functorch
transforms, because this is our plan for higher-order operators
(operators that accept functions as arguments). Examples of these
include:
- functorch transforms over the existing cond operator (control flow)
- autograd.Function support for functorch (which I am working towards),
- AOTDispatcher (should be a higher order operator)

Concretely, the problem with teaching PyDispatcher/PyOperator about
functorch is that the stack-based dispatching logic (DynamicLayerStack)
is hidden inside the fallbacks for two dispatch keys
(DynamicLayer{Front, Back}). PyDispatcher doesn't know about C++ boxed
fallbacks; our plan of record is to reimplement all of them in Python
(calling helper functions in C++ to make our lives easier).

Instead of exposing all of what DynamicLayer{Front, Back} do to python,
this PR takes the approach of re-implementing part of the stack-based
dispatching in Python. The motivation is that this is more sane and
follows what the "ideal" implementation of functorch would have been:
- each transform should be a "mode"
- there should be no TLS dispatch key set hackery. functorch needs to do
this hackery today to re-use VariableType implementations.

This PR:
- exposes the DynamicLayerStack to Python
- The DynamicLayerStack is a stack of Interpreters.
These get exposed to Python as well.
- Interpreters can run operations (Interpreter.process) or lower them to
the next interpreter in the stack (Interpreter.lower); see the sketch
after this list.
- To use a PyOperator with functorch transforms, a developer needs to
register a rule for each transform (vmap, grad, jvp, ...).
- The PyOperator API is NOT user-facing. Things like autograd.Function
support for functorch will end up going through the autograd.Function
API.
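
A pseudo-sketch of the process/lower loop those bullets describe (every name here is an assumption for exposition; the PR exposes the real machinery via the DynamicLayerStack bindings):

```py
from contextlib import contextmanager

class Interpreter:
    """Illustrative stand-in for an exposed functorch interpreter."""
    def __init__(self, stack, key):
        self.stack, self.key = stack, key

    def process(self, rule, args):
        return rule(self, *args)  # run the rule at this level

    @contextmanager
    def lower(self):
        # Temporarily pop this layer so the op re-dispatches below it.
        self.stack.pop()
        try:
            yield
        finally:
            self.stack.append(self)

STACK: list = []  # the DynamicLayerStack; top = innermost transform

def py_dispatch(op, *args):
    if not STACK:
        return op.default_impl(*args)          # bottom: the real kernel
    interp = STACK[-1]
    rule = op.functorch_rules.get(interp.key)  # per-transform rule
    if rule is not None:
        return interp.process(rule, args)      # Interpreter.process
    with interp.lower():                       # Interpreter.lower
        return py_dispatch(op, *args)
```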

Question for reviewers:
- Does this design make sense?
- I'm trying to split up the "functorch support for autograd.Function"
work into logical pieces. Would it be better if I didn't? (the full
thing is a bit long - 1000-2000 LOC).

Test Plan:
- new tests that construct PyOperator and compose them with functorch
transforms
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88785
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-11-16 00:46:59 +00:00