Commit Graph

17 Commits

Author SHA1 Message Date
Xuehai Pan
0c450f4504 [functorch] fix potential race condition while loading vmap decomposition library (#113520)
There can be a potential race condition while loading the `vmap` decomposition library in multi-threading programs.

This PR adds a thread lock to avoid the case of registering the kernel multiple times.

```python
import threading
from torch._functorch.vmap import lazy_load_decompositions

threads = []
for i in range(10000):
    thread = threading.Thread(target=lazy_load_decompositions)
    threads.append(thread)
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()
```

```text
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
    VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113520
Approved by: https://github.com/zou3519
2023-11-20 19:50:54 +00:00
Xuehai Pan
a7a0955790 [pytree][BE] reorganize imports and format code style and update type hints (#112268)
Reland PR:

- #112109

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112268
Approved by: https://github.com/Skylion007
2023-10-28 16:30:24 +00:00
Kazuaki Ishizaki
6d7744ca46 Fix typo under torch/_functorch directory (#111067)
This PR fixes typo the the of comments and exception messages in files under `torch/_functorch` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111067
Approved by: https://github.com/Skylion007
2023-10-11 23:09:36 +00:00
kshitij12345
cce2c52b0b [pt2] support vmap (#101707)
Teach dynamo about `vmap`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101707
Approved by: https://github.com/zou3519
2023-08-09 03:39:33 +00:00
Kshiteej K
a899333ffc fix: nll_loss batch rule with negative ignore_idx (#106118)
We use python decompositions instead of writing our own for batching rules.

Fixes https://github.com/pytorch/pytorch/issues/105736

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106118
Approved by: https://github.com/lezcano, https://github.com/zou3519
2023-08-04 07:43:02 +00:00
yhl48
07c02b9e92 Add vmap support for smooth_l1_loss_backward (#99429)
Follow-up of #98357
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99429
Approved by: https://github.com/kshitij12345, https://github.com/zou3519
2023-04-28 10:58:07 +00:00
Li-Huai (Allan) Lin
6f181aae7c [vmap] Register decomposition for huber_loss_backward (#99236)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99236
Approved by: https://github.com/kshitij12345
2023-04-16 18:50:45 +00:00
Nikita Shulga
b2f3ff6183 [Py3.11] Remove skip logic from vmap and forward_ad (#91825)
Depends on https://github.com/pytorch/pytorch/pull/91805

Fixes https://github.com/pytorch/pytorch/issues/85506
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91825
Approved by: https://github.com/albanD
2023-01-25 22:40:56 +00:00
Sean Ross-Ross
fb3d9f39cc update vmap to accept nones (#91644)
* Fixes https://github.com/pytorch/functorch/issues/1082
* Fixes https://github.com/pytorch/functorch/issues/439

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91644
Approved by: https://github.com/kshitij12345, https://github.com/Chillee
2023-01-20 18:25:22 +00:00
kshitij12345
4437d0d161 [functorch] vmap: chunk_size support (#91157)
Ref: https://github.com/pytorch/functorch/issues/680

We introduce a kwarg `chunk_size` in vmap.

Also, we leverage most of the code from `chunk_vmap` (except for chunking the input based on `chunk_size`)

Benchmarks from https://github.com/pytorch/functorch/pull/774 apply.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91157
Approved by: https://github.com/zou3519
2022-12-22 19:45:45 +00:00
Richard Zou
fb2e1878cb [torch.func] alias torch.func.vmap as torch.vmap (#91026)
This PR also redirects torch.vmap to torch.func.vmap instead of the old
vmap prototype.

Test Plan:
- tests
- view docs preview
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91026
Approved by: https://github.com/albanD, https://github.com/samdow
2022-12-21 20:51:49 +00:00
Richard Zou
31981d0139 [generate_vmap_rule] add restore_vmap helper function (#90963)
As seen in
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

`restore_vmap` is a private helper function. It is vmap but has the
following
differences:
- instead of returning outputs, it returns an (outputs, out_dims) tuple.
  out_dims is a pytree of shape shape as outputs and contains Optional[int]
  specifying where the vmapped dimension, if it exists, is in the
  corresponding output.
- does no validation on in_dims or inputs (vmap expects at least one
  Tensor to be vmapped).
  restore_vmap allows for no inputs to have the vmap dimension
- does no validation on outputs (vmap expects only Tensor outputs)
  restore_vmap allows for return of arbitrary outputs (not just
  Tensors)

Test Plan:
- added some simple test to test restore_vmap
- I am OK with restore_vmap not being a part of vmap right now -- the
implementation of vmap rarely changes and it is a bit difficult to
refactor vmap in a way that restore_vmap is a subroutine.

Other questions:
- Bikeshedding the `restore_vmap` name
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90963
Approved by: https://github.com/samdow, https://github.com/soulitzer
2022-12-21 00:34:41 +00:00
Richard Zou
41846e205e [torch.func] Setup torch.func, populate it with all transforms (#91016)
This PR sets up torch.func and populates it with the following APIs:
- grad
- grad_and_value
- vjp
- jvp
- jacrev
- jacfwd
- hessian
- functionalize
- vmap

It also renames all instances of `functorch` in the APIs for those docs
to `torch.func`.

We rewrite the `__module__` fields on some of the above APIs so that the
APIs fit PyTorch's public api definition.
- For an API to be public, it must have a `__module__` that points to a
  public PyTorch submodule. However, `torch._functorch.eager_transforms`
  is not public due to the leading underscore.
- The solution is to rewrite `__module__` to point to where the API is
  exposed (torch.func). This is what both Numpy and JAX do for their
  APIs.
- h/t pmeier in
  https://github.com/pytorch/pytorch/issues/90284#issuecomment-1348595246
  for idea and code
- The helper function, `exposed_in`, is confined to
  torch._functorch/utils for now because we're not completely sure if
  this should be the long-term solution.

Implication for functorch.* APIs:
- functorch.grad is the same object as torch.func.grad
- this means that the functorch.grad docstring is actually the
  torch.func.grad docstring and will refer to torch.func instead of
  functorch.
- This isn't really a problem since the plan on record is to deprecate
  functorch in favor of torch.func. We can fix these if we really want,
  but I'm not sure if a solution is worth maintaining.

Test Plan:
- view docs preview

Future:
- vmap should actually just be torch.vmap. This requires an extra step
  where I need to test internal callsites, so, I'm separating it into a
  different PR.
- make_fx should be in torch.func to be consistent with `import
  functorch`. This one is a bit more of a headache to deal with w.r.t.
  public api, so going to deal with it separately.
- beef up func.rst with everything else currently on the functorch
  documention website. func.rst is currently just an empty shell.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91016
Approved by: https://github.com/samdow
2022-12-20 00:00:52 +00:00
Richard Zou
cad1ce6158 Stop using :attr: in functorch docs (#91015)
We're using :attr: wrong. :attr: refers to an attribute of a Python
object, not the parameter to a function:
- https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#role-py-attr

This leads to some weird things when moving to torch.func: sphinx
decides to link torch.func for :attr:`func`

Test Plan:
- docs preview.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91015
Approved by: https://github.com/samdow
2022-12-20 00:00:52 +00:00
Richard Zou
4068c5467d [Reland] Move functorch/_src to torch/_functorch (#88756) (#90091)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091
Approved by: https://github.com/anijain2305, https://github.com/ezyang
2022-12-03 14:17:15 +00:00
PyTorch MergeBot
218d9c6e09 Revert "Move functorch/_src to torch/_functorch (#88756)"
This reverts commit 52bc5c1cfe.

Reverted https://github.com/pytorch/pytorch/pull/88756 on behalf of https://github.com/clee2000 due to broke imports in tests 52bc5c1cfe https://github.com/pytorch/pytorch/actions/runs/3574742513/jobs/6010814968 probably a landrace
2022-11-29 17:17:11 +00:00
Richard Zou
52bc5c1cfe Move functorch/_src to torch/_functorch (#88756)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88756
Approved by: https://github.com/ezyang
2022-11-29 13:55:42 +00:00