There is a potential race condition when the `vmap` decomposition library is loaded from multiple threads at once.
This PR adds a thread lock so that the kernels are registered only once. The following script reproduces the failure:
```python
import threading

from torch._functorch.vmap import lazy_load_decompositions

# Spawn many threads that all trigger the lazy decomposition loading at once.
threads = []
for i in range(10000):
    thread = threading.Thread(target=lazy_load_decompositions)
    threads.append(thread)
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()
```
```text
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
RuntimeError: This is not allowed since there's already a kernel registered from python overriding mse_loss_backward's behavior for FuncTorchBatched dispatch key and aten namespace.
```
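One way to serialize the registration is a double-checked lock around the loading step. A minimal sketch of that pattern (the lock and flag names here are hypothetical; the actual identifiers in torch._functorch.vmap may differ):
```python
import threading

# Hypothetical names; the real lock/flag in torch._functorch.vmap may differ.
_decomp_load_lock = threading.Lock()
_decompositions_loaded = False

def _register_decompositions():
    # Placeholder for the actual VMAP_DECOMPOSITIONS_LIB.impl(...) registrations.
    pass

def lazy_load_decompositions():
    global _decompositions_loaded
    if _decompositions_loaded:       # fast path once loading has completed
        return
    with _decomp_load_lock:
        if _decompositions_loaded:   # re-check under the lock
            return
        _register_decompositions()   # only one thread ever registers the kernels
        _decompositions_loaded = True
```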
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113520
Approved by: https://github.com/zou3519
As seen in
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit
`restore_vmap` is a private helper function. It is vmap, but with the
following differences:
- instead of returning outputs, it returns an (outputs, out_dims) tuple;
out_dims is a pytree with the same shape as outputs and contains
Optional[int] entries specifying where the vmapped dimension, if it
exists, is in the corresponding output.
- it does no validation on in_dims or inputs (vmap expects at least one
Tensor to be vmapped); restore_vmap allows no inputs to have the vmap
dimension
- it does no validation on outputs (vmap expects only Tensor outputs);
restore_vmap allows arbitrary outputs (not just Tensors) to be
returned. A usage sketch follows this list.
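For illustration, a rough usage sketch. The signature assumed here, `restore_vmap(func, in_dims, batch_size, randomness)`, is an assumption and may not match the private helper exactly:
```python
import torch
from torch._functorch.vmap import restore_vmap  # private helper; path/signature may differ

def f(x):
    # Non-Tensor outputs are allowed, unlike with plain vmap.
    return x.sum(), "metadata"

x = torch.randn(3, 4)
# Assumed signature: restore_vmap(func, in_dims, batch_size, randomness) -> callable.
outputs, out_dims = restore_vmap(f, (0,), 3, "error")(x)
# out_dims mirrors the structure of outputs; each entry is an Optional[int] giving
# the position of the vmapped dimension in that output (None if it was not vmapped).
print(out_dims)  # e.g. (0, None)
```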
Test Plan:
- added some simple tests for restore_vmap
- I am OK with restore_vmap not being a part of vmap right now -- the
implementation of vmap rarely changes, and it is a bit difficult to
refactor vmap in a way that makes restore_vmap a subroutine.
Other questions:
- Bikeshedding the `restore_vmap` name
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90963
Approved by: https://github.com/samdow, https://github.com/soulitzer
This PR sets up torch.func and populates it with the following APIs:
- grad
- grad_and_value
- vjp
- jvp
- jacrev
- jacfwd
- hessian
- functionalize
- vmap
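For illustration, a brief usage sketch of two of these APIs under the new namespace:
```python
import torch
from torch.func import grad, vmap

def f(x):
    return (x ** 2).sum()

x = torch.randn(3)
print(grad(f)(x))          # gradient of f at x, i.e. 2 * x
print(vmap(torch.sin)(x))  # torch.sin applied over the leading batch dimension
```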
It also renames all instances of `functorch` to `torch.func` in the docs
for those APIs.
We rewrite the `__module__` fields on some of the above APIs so that they
fit PyTorch's public API definition.
- For an API to be public, it must have a `__module__` that points to a
public PyTorch submodule. However, `torch._functorch.eager_transforms`
is not public due to the leading underscore.
- The solution is to rewrite `__module__` to point to where the API is
exposed (torch.func). This is what both NumPy and JAX do for their
APIs.
- h/t pmeier in
https://github.com/pytorch/pytorch/issues/90284#issuecomment-1348595246
for the idea and code
- The helper function, `exposed_in`, is confined to
torch._functorch/utils for now because we're not completely sure if
this should be the long-term solution; a sketch of the idea follows this list.
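For illustration, a decorator along these lines could perform the `__module__` rewrite (a sketch only; the actual helper in torch._functorch/utils may differ):
```python
def exposed_in(module):
    """Decorator: mark the wrapped function as publicly exposed in `module`
    by rewriting its __module__ attribute."""
    def wrapper(fn):
        fn.__module__ = module
        return fn
    return wrapper

# Hypothetical usage on an API defined in a private submodule:
@exposed_in("torch.func")
def grad(func, argnums=0, has_aux=False):
    ...
```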
Implication for functorch.* APIs:
- functorch.grad is the same object as torch.func.grad (see the snippet
after this list)
- this means that the functorch.grad docstring is actually the
torch.func.grad docstring and will refer to torch.func instead of
functorch.
- This isn't really a problem since the plan on record is to deprecate
functorch in favor of torch.func. We can fix these if we really want,
but I'm not sure if a solution is worth maintaining.
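For illustration, a small snippet showing the aliasing described above (valid as of this PR, while the functorch package is still importable):
```python
import functorch
import torch.func

# The functorch entry point is a re-export of the torch.func object...
assert functorch.grad is torch.func.grad
# ...so its __module__ (and therefore its docs) point at torch.func.
print(functorch.grad.__module__)  # expected: "torch.func"
```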
Test Plan:
- view docs preview
Future:
- vmap should actually just be torch.vmap. This requires an extra step
where I need to test internal callsites, so I'm separating it into a
different PR.
- make_fx should be in torch.func to be consistent with `import
functorch`. This one is a bit more of a headache to deal with w.r.t. the
public API, so I'm going to deal with it separately.
- beef up func.rst with everything else currently on the functorch
documentation website. func.rst is currently just an empty shell.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91016
Approved by: https://github.com/samdow
This will be the last disruptive functorch internals change.
Why are we moving these files?
- As part of rationalizing functorch, we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.
Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times
Test Plan:
- wait for tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091
Approved by: https://github.com/anijain2305, https://github.com/ezyang
This will be the last disruptive functorch internals change.
Why are we moving these files?
- As part of rationalizing functorch, we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.
Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times
Test Plan:
- wait for tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88756
Approved by: https://github.com/ezyang