Commit Graph

65 Commits

Author SHA1 Message Date
Richard Zou
7d10298067 Implement Tensor.to batching rule (#43206)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43206

The batching rule is the same as the unary pointwise batching rules:
given a BatchedTensor, we unwrap it, call Tensor.to, and then re-wrap
it.
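
A usage sketch of what this enables (not part of the commit; the import location of the prototype vmap frontend is an assumption):

```
import torch
from torch import vmap  # assumption: where the prototype frontend is exposed

x = torch.randn(3, 5)                       # dim 0 is the batch dim
y = vmap(lambda t: t.to(torch.float64))(x)  # Tensor.to applied per example
assert y.dtype == torch.float64 and y.shape == (3, 5)
```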

Test Plan: - `pytest test/test_vmap.py -v -k`

Reviewed By: ezyang

Differential Revision: D23189053

Pulled By: zou3519

fbshipit-source-id: 51b4e41b1cd34bd082082ec4fff3c643002edbaf
2020-08-19 10:54:26 -07:00
Richard Zou
37252e8f00 Implement batching rules for some unary ops (#43059)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43059

This PR implements batching rules for some unary ops. In particular, it
implements the batching rules for the unary ops that take a single
tensor as input (and nothing else).

The batching rule for a unary op is (sketched below):
(1) grab the physical tensor straight out of the BatchedTensor
(2) call the unary op
(3) rewrap the physical tensor in a BatchedTensor
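
A conceptual Python rendering of this pattern (the real batching rules live in C++/ATen; `Batched` below is a hypothetical stand-in for the BatchedTensor wrapper, used only to make the sketch runnable):

```
from collections import namedtuple
import torch

# Hypothetical stand-in for the C++ BatchedTensor wrapper: physical tensor + batch dim
Batched = namedtuple("Batched", ["value", "bdim"])

def unary_pointwise_batching_rule(op, batched):
    physical = batched.value              # (1) grab the physical tensor
    result = op(physical)                 # (2) call the unary op
    return Batched(result, batched.bdim)  # (3) rewrap with the same batch dim

# e.g. applying torch.sin "per example" without an explicit loop:
out = unary_pointwise_batching_rule(torch.sin, Batched(torch.randn(3, 5), 0))
```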

Test Plan: - new tests `pytest test/test_vmap.py -v -k "Operators"`

Reviewed By: ezyang

Differential Revision: D23132277

Pulled By: zou3519

fbshipit-source-id: 24b9d7535338207531d767155cdefd2c373ada77
2020-08-17 13:38:10 -07:00
Richard Zou
768c2a8c25 vmap: fixed to work with functools.partial (#43028)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43028

There was a bug where we always tried to grab the `__name__` attribute of
the function passed in by the user. Not all Callables have the
`__name__` attribute; one example is a Callable produced by
`functools.partial`.

This PR modifies the error-checking code to use `repr` if `__name__` is
not available. Furthermore, it moves the "get the name of this function"
functionality to the actual error sites as an optimization so we don't
spend time trying to compute `__repr__` for the Callable if there is no
error.
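
A hypothetical helper illustrating the described behavior (the actual name and location in the PR may differ):

```
import functools
import torch

def _get_name(func):
    # Hypothetical helper: prefer __name__, but not all Callables have it
    # (e.g. functools.partial objects), so fall back to repr. It is only
    # called at the error sites, so repr is never computed on the happy path.
    if hasattr(func, '__name__'):
        return func.__name__
    return repr(func)

assert _get_name(torch.sin) == 'sin'
assert 'partial' in _get_name(functools.partial(torch.add, 1))
```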

Test Plan: - `pytest test/test_vmap.py -v`, added new tests.

Reviewed By: yf225

Differential Revision: D23130235

Pulled By: zou3519

fbshipit-source-id: 937f3640cc4d759bf6fa38b600161f5387a54dcf
2020-08-17 13:36:49 -07:00
Richard Zou
bda0007620 Improve calling backward() and grad() inside vmap error messages (#42876)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42876

Previously, the error messages were pretty bad. This PR adds nice
error messages for the following cases:
- user attempts to call .backward() inside vmap for any reason
whatsoever
- user attempts to call autograd.grad(outputs, inputs, grad_outputs),
where outputs or inputs is being vmapped over (so they are
BatchedTensors).

The case we do support is calling autograd.grad(outputs, inputs,
grad_outputs) where `grad_outputs` is being vmapped over. This is the
case for batched gradient support (e.g., user passes in a batched
grad_output).
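
A sketch of the supported case, computing rows of a Jacobian by vmapping over grad_outputs (the vmap import location is an assumption about the prototype frontend):

```
import torch
from torch import vmap  # assumption: where the prototype frontend is exposed

x = torch.randn(3, requires_grad=True)
y = x.exp()                              # outputs: NOT vmapped over

def vjp(grad_output):
    # only grad_output is batched; outputs and inputs are ordinary tensors
    return torch.autograd.grad(y, x, grad_output, retain_graph=True)[0]

batched_grad_outputs = torch.eye(3)          # one grad_output per row
jac_rows = vmap(vjp)(batched_grad_outputs)   # each row: dy/dx for one basis vector
assert jac_rows.shape == (3, 3)
```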

Test Plan: - new tests: `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23059836

Pulled By: zou3519

fbshipit-source-id: 2fd4e3fd93f558e67e2f0941b18f0d00d8ab439f
2020-08-12 10:05:31 -07:00
Richard Zou
e8f4b04d9a vmap: temporarily disable support for random functions (#42617)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42617

While we figure out the random plan, I want to initially disable
support for random operations. This is because there is an ambiguity in
what randomness means. For example,

```
tensor = torch.zeros(B0, 1)  # B0 is the batch size being vmapped over
vmap(lambda t: t.normal_())(tensor)
```

In the above example, should tensor[0] and tensor[1] be equal (i.e.,
use the same random seed), or should they be different?

The mechanism for disabling random support is as follows:
- We add a new dispatch key called VmapMode
- Whenever we're inside vmap, we enable VmapMode for all tensors.
This is done via at::VmapMode::increment_nesting and
at::VmapMode::decrement_nesting.
- DispatchKey::VmapMode's fallback kernel is the fallthrough kernel.
- We register kernels that raise errors for all random functions on
DispatchKey::VmapMode. This way, whenever someone calls a random
function on any tensor (not just BatchedTensors) inside of a vmap block,
an error gets thrown (see the sketch below).
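
A sketch of the user-visible effect (the exact error text and the vmap import location are assumptions):

```
import torch
from torch import vmap  # assumption: where the prototype frontend is exposed

tensor = torch.zeros(2, 1)
try:
    vmap(lambda t: t.normal_())(tensor)
except RuntimeError as e:
    # expected to say that random operations are not supported inside vmap
    print(e)
```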

Test Plan: - pytest test/test_vmap.py -v -k "Operators"

Reviewed By: ezyang

Differential Revision: D22954840

Pulled By: zou3519

fbshipit-source-id: cb8d71062d4087e10cbf408f74b1a9dff81a226d
2020-08-11 07:19:51 -07:00
Richard Zou
8f67c7a624 BatchedTensor fallback: extended to support ops with multiple Tensor returns (#42628)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42628

This PR extends the BatchedTensor fallback to support operators with
multiple Tensor returns. If an operator has multiple returns, we stack
shards of each return to create the full outputs.
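
A usage sketch with an operator that has two Tensor returns (the vmap import location is an assumption):

```
import torch
from torch import vmap  # assumption: where the prototype frontend is exposed

x = torch.randn(4, 10)
# Under the fallback, var_mean runs on each slice and each return is
# stacked along a new dim 0, giving two batched outputs.
var, mean = vmap(torch.var_mean)(x)
assert var.shape == (4,) and mean.shape == (4,)
```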

Test Plan:
- `pytest test/test_vmap.py -v`. Added a new test for an operator with
multiple returns (torch.var_mean).

Reviewed By: izdeby

Differential Revision: D22957095

Pulled By: zou3519

fbshipit-source-id: 5c0ec3bf51283cc4493b432bcfed1acf5509e662
2020-08-10 17:42:03 -07:00
Richard Zou
f3e8fff0d2 Batching rules for: chunk, split, unbind (#42480)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42480

These are grouped together because they all return a tuple of multiple
tensors.

This PR implements batching rules for chunk, split, and unbind. It also
updates the testing logic. Previously, reference_vmap was not able to
handle multiple outputs; now it does.
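
A usage sketch with a tuple-returning op; each element of the returned tuple carries the batch dimension (the vmap import location is an assumption):

```
import torch
from torch import vmap  # assumption: where the prototype frontend is exposed

x = torch.randn(5, 6)
# chunk splits the per-example dim (size 6) in two; both outputs stay batched
a, b = vmap(lambda t: torch.chunk(t, 2, dim=0))(x)
assert a.shape == (5, 3) and b.shape == (5, 3)
```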

Test Plan: - `pytest test/test_vmap.py -v -k "Operators"`

Reviewed By: ezyang

Differential Revision: D22905401

Pulled By: zou3519

fbshipit-source-id: 9963c943d035e9035c866be74dbdf7ab1989f8c4
2020-08-04 08:33:43 -07:00
Richard Zou
f1d7f001b9 Batching rules for: torch.movedim, torch.narrow, Tensor.unfold (#42474)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42474

Test Plan: - `pytest test/test_vmap.py -v -k "Operators"`

Reviewed By: ezyang

Differential Revision: D22903513

Pulled By: zou3519

fbshipit-source-id: 06b3fb0c7d12b9a045c73a5c5a4f4e3207e07b02
2020-08-04 08:33:41 -07:00
Richard Zou
01cd613e7e Batching rules for: T, view, view_as, reshape, reshape_as (#42458)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42458

Test Plan: - `pytest test/test_vmap.py -v -k "Operators"`

Reviewed By: ezyang

Differential Revision: D22898715

Pulled By: zou3519

fbshipit-source-id: 47f374962697dcae1d5aec80a41085679d016f92
2020-08-04 08:31:33 -07:00
Richard Zou
4cdbe5c495 Implement batching rules for some view ops (#42248)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42248

Including:
- torch.diagonal
- torch.t
- torch.select
- Tensor.expand_as
- Tensor slicing.

Please let me know in the future if it would be easier to review these
separately (I put five operators into this PR because each
implementation is relatively simple).

Test Plan:
- new tests in `test/test_vmap.py`.
- I would like to have a more structured/automated way of testing, but
my previous attempts resulted in something very complicated.

Reviewed By: ezyang

Differential Revision: D22846273

Pulled By: zou3519

fbshipit-source-id: 8e45ebe11174512110faf1ee0fdc317a25e8b7ac
2020-08-03 08:01:48 -07:00
Richard Zou
2f8d5b68fa vmap fallback kernel (#41943)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41943

If an operator doesn't have a batching rule implemented then we fallback
to this implementation. The fallback only works on out-of-place operators
that return only tensors with new memory (e.g., no in-place operators,
no view operations).

The fallback effectively takes all of the BatchedTensors in `stack`,
slices them, and runs `op` on all of the corresponding slices to produce slices
of the outputs. The output slices then get `torch.stack`ed to create the
final returns.
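
A conceptual Python rendering of this strategy (the real fallback is a C++ boxed kernel; the simplified single-batched-input signature below is for illustration only):

```
import torch

def vmap_fallback(op, batched_input, batch_size):
    # run `op` on each slice of the batched input ...
    output_slices = [op(batched_input.select(0, i)) for i in range(batch_size)]
    # ... then stack the per-slice outputs into the final batched output;
    # this stack is the source of the extra copy discussed below.
    return torch.stack(output_slices)

out = vmap_fallback(torch.sin, torch.randn(4, 3), 4)
assert out.shape == (4, 3)
```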

The performance of the fallback is not very good because it introduces
an extra copy from stacking the sliced outputs. Because of this, we prefer
to write batching rules for operators whenever possible.

In the future, I'd like to disable the fallback kernel for random
functions until we have a better random story for vmap. I will probably
add a blocklist of operators to support that.

Test Plan: - `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D22764103

Pulled By: zou3519

fbshipit-source-id: b235833f7f27e11fb76a8513357ac3ca286a638b
2020-08-03 07:59:33 -07:00
Richard Zou
5d1d8a58b8 Enable in_dims for vmap frontend api (#40717)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40717

`in_dims` specifies which dimension of the input tensors should be
vmapped over. One can also specify `None` as an `in_dim` for a particular
input to indicate that we do not map over said input.

We implement `in_dims` by creating a BatchedTensor with BatchDim equal
to said `in_dim`. Most of this PR is error checking. `in_dims` must
satisfy the following:
- `in_dims` can be either an int or a Tuple[Optional[int]]. If it is an
int, we use it as the `in_dim` for every input.
- If `in_dims` is not None at some index `idx`, then the input at index
`idx` MUST be a tensor (vmap can only map over tensors).

jax supports something more generalized: their `in_dims` can match the
structure of the `inputs` to the function (i.e., it is a nested python
data structure matching the data structure of `inputs` specifying where
in `inputs` the Tensors to be mapped are and what their map dims should
be). We don't have the infrastructure yet, so we only support `int` or a
flat tuple for `in_dims`.
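
A usage sketch for `in_dims` (the vmap import location is an assumption):

```
import torch
from torch import vmap  # assumption: where the prototype frontend is exposed

x = torch.randn(2, 5)    # mapped over dim 0
y = torch.randn(5, 2)    # mapped over dim 1
w = torch.randn(5)       # not mapped at all (in_dim is None)

# Tuple form: one entry per input; None means "do not map over this input".
out = vmap(lambda a, b, c: a * b * c, in_dims=(0, 1, None))(x, y, w)
assert out.shape == (2, 5)

# Int form: the same in_dim (here 0) is used for every input.
out2 = vmap(torch.mul, in_dims=0)(x, torch.randn(2, 5))
assert out2.shape == (2, 5)
```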

Test Plan: - `pytest test/test_vmap.py -v`

Differential Revision: D22397914

Pulled By: zou3519

fbshipit-source-id: 56d2e14be8b6024e4cde2729eff384da305b4ea3
2020-07-06 19:14:43 -07:00
Richard Zou
a6a31bcd47 Enable out_dims for vmap frontend API (#40576)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40576

`out_dims` specifies where in the output tensors the vmapped dimension
should appear. We implement this by simply creating a view with the
batch dimension moved to the desired position.

`out_dims` must either:
- be int (use the same value for all outputs)
- be Tuple[int] (so the user specifies one out_dim per output).
(See the vmap docstring for what we advertise out_dims to do).
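
A usage sketch for `out_dims` (the vmap import location is an assumption):

```
import torch
from torch import vmap  # assumption: where the prototype frontend is exposed

x = torch.randn(2, 5)

# out_dims=1 places the vmapped dimension at dim 1 of each output.
out = vmap(lambda t: torch.mul(t, t), out_dims=1)(x)
assert out.shape == (5, 2)
```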

I also renamed `TestVmap` to `TestVmapAPI` to make it clearer that we
are testing the API here and not specific operators (which will go into
their own test class).

Test Plan: - `pytest test/test_vmap.py -v`

Differential Revision: D22288086

Pulled By: zou3519

fbshipit-source-id: c8666cb1a0e22c54473d8045477e14c2089167cf
2020-06-30 08:20:39 -07:00
Richard Zou
c362138f43 Disallow passing functions that don't return Tensors to vmap (#40518)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40518

I overlooked this in the initial vmap frontend API PR. Right now we
want to restrict vmap to taking in functions that only return Tensors.
A function that only returns tensors can look like one of the following:
```
def fn1(x):
    ...
    return y

def fn2(x):
    ...
    return y, z
```
fn1 returns a Tensor, while fn2 returns a tuple of Tensors. So we add a
check that the function passed to vmap returns either a single tensor
or a tuple of tensors.

NB: These checks allow passing to vmap a function that returns a
single-element tuple containing a tensor. That seems OK to me.

Test Plan: - `python test/test_vmap.py -v`

Differential Revision: D22216166

Pulled By: zou3519

fbshipit-source-id: a92215e9c26f6138db6b10ba81ab0c2c2c030929
2020-06-25 08:54:05 -07:00
Richard Zou
727463a727 Initial vmap frontend API (#40172)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40172

This PR introduces the initial vmap frontend API. It has the following
limitations that we can resolve in the future:
- the inputs must be a flat list of tensors
- the outputs must be a flat list of tensors
- in_dims = 0 (so we always vmap over dim 0 of input tensors)
- out_dims = 0 (so the returned tensors have their vmap dim appear at
dim 0)
- Coverage limited to operations that have batching rules implemented
(torch.mul, torch.sum, torch.expand).

There are some other semantic limitations (like not being able to handle
mutation, aside from pytorch operations that perform mutation) that will
be documented in the future.

I wanted to introduce the API before adding a slow fallback for the
coverage so that we can test future batching rules (and coverage) via
the python API to avoid verbosity in C++-land.

The way vmap works is that `vmap(func)(inputs)` wraps all Tensor inputs
to be batched in BatchedTensors, sends those into func, and then unwraps
the output BatchedTensors. Operations on BatchedTensors perform the batched
operations that the user is asking for. When performing nested vmaps,
each nested vmap adds a batch dimension upon entry and removes a batch
dimension on exit.
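
A usage sketch of the initial API, including a nested vmap (the vmap import location is an assumption):

```
import torch
from torch import vmap  # assumption: where the prototype frontend is exposed

x = torch.randn(3)
y = torch.randn(5)

# Nested vmap: the outer vmap batches over x, the inner over y, so the
# elementwise torch.mul produces an outer product of shape (3, 5).
outer_product = vmap(lambda xi: vmap(lambda yj: torch.mul(xi, yj))(y))(x)
assert outer_product.shape == (3, 5)
```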

Coming up in the near future:
- Support for non-zero in_dims and out_dims
- docstring for vmap
- slow fallback for operators that do not have a batching rule
implemented.

Test Plan: - `pytest test/test_vmap.py -v`

Differential Revision: D22102076

Pulled By: zou3519

fbshipit-source-id: b119f0a8a3a3b1717c92dbbd180dfb1618295563
2020-06-24 08:14:24 -07:00