Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40576
`out_dims` specifies where in the output tensors the vmapped dimension
should appear. We implement this by simply creating a view with the
batch dimension moved to the desired position.
`out_dims` must be either:
- an int (the same value is used for all outputs), or
- a Tuple[int] (one out_dim per output).
(See the vmap docstring for what we advertise out_dims to do).
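As a rough usage sketch (assuming the frontend is exposed as `torch.vmap`; the moved batch dim behaves as if by `movedim`):
```
import torch

xs = torch.randn(5, 3)

# out_dims=0: the vmapped dim appears at dim 0 of the output -> shape (5, 3).
y0 = torch.vmap(lambda x: x * 2, out_dims=0)(xs)
assert y0.shape == (5, 3)

# out_dims=1: the same result viewed with the batch dim moved to dim 1
# (as if by movedim(0, 1)) -> shape (3, 5).
y1 = torch.vmap(lambda x: x * 2, out_dims=1)(xs)
assert y1.shape == (3, 5)

# Tuple[int]: one out_dim per output.
f = lambda x: (x * 2, x.sum())
a, b = torch.vmap(f, out_dims=(1, 0))(xs)
assert a.shape == (3, 5) and b.shape == (5,)
```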
I also renamed `TestVmap` to `TestVmapAPI` to make it clearer that we
are testing the API here and not specific operators (which will go into
their own test class).
Test Plan: - `pytest test/test_vmap.py -v`
Differential Revision: D22288086
Pulled By: zou3519
fbshipit-source-id: c8666cb1a0e22c54473d8045477e14c2089167cf
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40575
This provides some more context for the next ~2 PRs that will implement
the `out_dims` and `in_dims` functionality. I will probably add more to
it later; things I think we should add include usage examples (maybe on
a dedicated docs page) and specific examples of things vmap cannot handle.
Test Plan:
- Code reading for now. When we are ready to add vmap to master documentation,
I'll build the docs and fix any formatting problems.
Differential Revision: D22288085
Pulled By: zou3519
fbshipit-source-id: 6e28d7bd524242395160c20270159b4b121d6789
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40518
I overlooked this in the initial vmap frontend API PR. Right now we
want to restrict vmap to taking in functions that only return Tensors.
A function that only returns Tensors can look like one of the following:
```
def fn1(x):
    ...
    return y

def fn2(x):
    ...
    return y, z
```
fn1 returns a single Tensor, while fn2 returns a tuple of Tensors. So we
add a check that the function passed to vmap returns either a single
Tensor or a tuple of Tensors.
NB: These checks still allow passing vmap a function that returns a
single-element tuple of Tensors (i.e., `(y,)` rather than `y`). That
seems OK to me.
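A sketch of the contract these checks enforce (assuming the frontend is exposed as `torch.vmap`; the commented-out call illustrates the rejected case):
```
import torch

def returns_tensor(x):
    return x.sum()            # a single Tensor: OK

def returns_tuple(x):
    return x * 2, x.sum()     # a tuple of Tensors: OK

def returns_list(x):
    return [x.sum()]          # a list, not a Tensor or tuple: rejected

xs = torch.randn(5, 3)
torch.vmap(returns_tensor)(xs)  # fine
torch.vmap(returns_tuple)(xs)   # fine
# torch.vmap(returns_list)(xs)  # expected to raise under this restriction
```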
Test Plan: - `python test/test_vmap.py -v`
Differential Revision: D22216166
Pulled By: zou3519
fbshipit-source-id: a92215e9c26f6138db6b10ba81ab0c2c2c030929
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40172
This PR introduces the initial vmap frontend API. It has the following
limitations that we can resolve in the future:
- the inputs must be a flat list of tensors
- the outputs must be a flat list of tensors
- in_dims = 0 (so we always vmap over dim 0 of input tensors)
- out_dims = 0 (so the returned tensors have their vmap dim appear at
dim 0)
- Coverage limited to operations that have batching rules implemented
(torch.mul, torch.sum, torch.expand); a usage sketch follows this list.
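Under these constraints, usage looks roughly like this (a sketch, assuming the frontend is exposed as `torch.vmap`):
```
import torch

# Batch of 8 vectors; in_dims=0, so we vmap over dim 0 of each input.
x = torch.randn(8, 3)
y = torch.randn(8, 3)

# torch.mul has a batching rule, so this executes as one batched multiply
# rather than a Python-level loop over the 8 examples.
out = torch.vmap(torch.mul)(x, y)
assert out.shape == (8, 3)
assert torch.allclose(out, x * y)
```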
There are some other semantic limitations (like not being able to handle
mutation, aside from PyTorch operations that perform mutation) that will
be documented in the future.
I wanted to introduce the API before adding a slow fallback for coverage
so that we can test future batching rules (and coverage) via the Python
API and avoid verbosity in C++-land.
The way vmap works is that `vmap(func)(inputs)` wraps all Tensor inputs
to be batched in BatchedTensors, sends those into func, and then unwraps
the output BatchedTensors. Operations on BatchedTensors perform the batched
operations that the user is asking for. When performing nested vmaps,
each nested vmap adds a batch dimension upon entry and removes a batch
dimension on exit.
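For intuition, here is a toy, loop-based stand-in for that wrap/run/unwrap flow (illustrative only; `toy_vmap` is hypothetical and uses a Python loop where the real implementation uses BatchedTensors and per-operator batching rules):
```
import torch

def toy_vmap(func):
    def wrapped(*inputs):
        # Treat dim 0 of each input as the batch dim: run func on each slice.
        batch_size = inputs[0].shape[0]
        per_example = [func(*(t[i] for t in inputs)) for i in range(batch_size)]
        # Reassemble the per-example outputs with the batch dim back at dim 0.
        if isinstance(per_example[0], tuple):
            return tuple(torch.stack(outs) for outs in zip(*per_example))
        return torch.stack(per_example)
    return wrapped

x = torch.randn(4, 3)
y = torch.randn(4, 3)
assert torch.allclose(toy_vmap(torch.mul)(x, y), x * y)
```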
Coming up in the near future:
- support for non-zero in_dims and out_dims
- a docstring for vmap
- a slow fallback for operators that do not have a batching rule
implemented
Test Plan: - `pytest test/test_vmap.py -v`
Differential Revision: D22102076
Pulled By: zou3519
fbshipit-source-id: b119f0a8a3a3b1717c92dbbd180dfb1618295563