Commit Graph

4 Commits

Author SHA1 Message Date
Richard Zou
a6a31bcd47 Enable out_dims for vmap frontend API (#40576)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40576

`out_dims` specifies where in the output tensors the vmapped dimension
should appear. We implement this by simply creating a view with the
batch dimension moved to the desired position.

`out_dims` must either:
- be int (use the same value for all outputs)
- be Tuple[int] (so the user specifies one out_dim per output).
(See the vmap docstring for what we advertise out_dims to do).

I also renamed `TestVmap` to `TestVmapAPI` to make it clearer that we
are testing the API here and not specific operators (which will go into
their own test class).

Test Plan: - `pytest test/test_vmap.py -v`

Differential Revision: D22288086

Pulled By: zou3519

fbshipit-source-id: c8666cb1a0e22c54473d8045477e14c2089167cf
2020-06-30 08:20:39 -07:00
Richard Zou
2f94b7f95c Initial vmap docstring (#40575)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40575

This provides some more context for the next ~2 PRs that will implement
the `out_dims` and `in_dims` functionality. I will probably add more to
it later (things I think we should add: examples (maybe in a dedicated
docs page), specific examples of things vmap cannot handle).

Test Plan:
- Code reading for now. When we are ready to add vmap to master documentation,
I'll build the docs and fix any formatting problems.

Differential Revision: D22288085

Pulled By: zou3519

fbshipit-source-id: 6e28d7bd524242395160c20270159b4b121d6789
2020-06-30 08:18:20 -07:00
Richard Zou
c362138f43 Disallow passing functions that don't return Tensors to vmap (#40518)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40518

I overlooked this in the initial vmap frontend api PR. Right now we
want to restrict vmap to taking in functions that only return Tensors.
A function that only return tensors can look like one of the following:
```
def fn1(x):
    ...
    return y

def fn2(x):
    ...
    return y, z
```
fn1 returns a Tensor, while fn2 returns a tuple of Tensors. So we add a
check that the output of the function passed to vmap returns either a
single tensor or a tuple of tensors.

NB: These checks allow passing a function that returns a tuple with a
single-element tensor from vmap. That seems OK to me.

Test Plan: - `python test/test_vmap.py -v`

Differential Revision: D22216166

Pulled By: zou3519

fbshipit-source-id: a92215e9c26f6138db6b10ba81ab0c2c2c030929
2020-06-25 08:54:05 -07:00
Richard Zou
727463a727 Initial vmap frontend API (#40172)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40172

This PR introduces the initial vmap frontend API. It has the following
limitations that we can resolve in the future:
- the inputs must be a flat list of tensors
- the outputs must be a flat list of tensors
- in_dims = 0 (so we always vmap over dim 0 of input tensors)
- out_dims = 0 (so the returned tensors have their vmap dim appear at
dim 0)
- Coverage limited to operations that have batching rules implemented
(torch.mul, torch.sum, torch.expand).

There are some other semantic limitations (like not being able to handle
mutation, aside from pytorch operations that perform mutation) that will
be documented in the future.

I wanted to introduce the API before adding a slow fallback for the
coverage so that we can test future batching rules (and coverage) via
the python API to avoid verbosity in C++-land.

The way vmap works is that `vmap(func)(inputs)` wraps all Tensor inputs
to be batched in BatchedTensors, sends those into func, and then unwraps
the output BatchedTensors. Operations on BatchedTensors perform the batched
operations that the user is asking for. When performing nested vmaps,
each nested vmap adds a batch dimension upon entry and removes a batch
dimension on exit.

Coming up in the near future:
- Support for non-zero in_dims and out_dims
- docstring for vmap
- slow fallback for operators that do not have a batching rule
implemented.

Test Plan: - `pytest test/test_vmap.py -v`

Differential Revision: D22102076

Pulled By: zou3519

fbshipit-source-id: b119f0a8a3a3b1717c92dbbd180dfb1618295563
2020-06-24 08:14:24 -07:00