Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43206
The batching rule is the same as the unary pointwise batching rules:
given a BatchedTensor, we unwrap it, call Tensor.to, and then re-wrap
it.
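A hedged usage-level sketch of what this enables (not the C++ batching rule itself); it assumes the prototype is exposed as `torch.vmap`:
```
import torch

# Each slice of x is converted to float64; under the hood the batching rule
# unwraps the BatchedTensor, calls Tensor.to, and re-wraps the result.
x = torch.randn(3, 5)
y = torch.vmap(lambda t: t.to(torch.float64))(x)
assert y.shape == (3, 5) and y.dtype == torch.float64
```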
Test Plan: - `pytest test/test_vmap.py -v -k`
Reviewed By: ezyang
Differential Revision: D23189053
Pulled By: zou3519
fbshipit-source-id: 51b4e41b1cd34bd082082ec4fff3c643002edbaf
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43059
This PR implements batching rules for some unary ops. In particular, it
implements the batching rules for the unary ops that take a single
tensor as input (and nothing else).
The batching rule for a unary op is:
(1) grab the physical tensor straight out of the BatchedTensor
(2) call the unary op
(3) rewrap the physical tensor in a BatchedTensor
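A minimal Python sketch of this pattern (the real rule is registered in C++; `physical` and `bdim` below are illustrative names, not public API):
```
import torch

# Illustrative-only sketch of a unary-op batching rule: run the op on the
# physical tensor and re-wrap the result with the same batch-dim metadata.
def unary_batching_rule(op, physical, bdim):
    return op(physical), bdim

# At the user level, vmap(torch.abs) then behaves like calling torch.abs on
# the whole batch at once:
x = torch.randn(4, 3)
assert torch.allclose(torch.vmap(torch.abs)(x), torch.abs(x))
```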
Test Plan: - new tests `pytest test/test_vmap.py -v -k "Operators"`
Reviewed By: ezyang
Differential Revision: D23132277
Pulled By: zou3519
fbshipit-source-id: 24b9d7535338207531d767155cdefd2c373ada77
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43028
There was a bug where we always tried to grab the `__name__` attribute of
the function passed in by the user. Not all Callables have the
`__name__` attribute, an example being a Callable produced by
functools.partial.
This PR modifies the error-checking code to use `repr` if `__name__` is
not available. Furthermore, it moves the "get the name of this function"
functionality to the actual error sites as an optimization so we don't
spend time trying to compute `__repr__` for the Callable if there is no
error.
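A minimal sketch of the pattern described above (the helper name `_get_name` is hypothetical, not the actual code):
```
import functools

# Prefer __name__, fall back to repr; only called at the error site, so
# repr() is not computed unless an error message is actually being built.
def _get_name(fn):
    return getattr(fn, "__name__", repr(fn))

double = functools.partial(lambda a, b: a * b, 2)  # has no __name__ attribute
assert not hasattr(double, "__name__")
print(_get_name(double))  # falls back to repr(double)
```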
Test Plan: - `pytest test/test_vmap.py -v`, added new tests.
Reviewed By: yf225
Differential Revision: D23130235
Pulled By: zou3519
fbshipit-source-id: 937f3640cc4d759bf6fa38b600161f5387a54dcf
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42876
Previously, the error messages were pretty bad. This PR adds nice
error messages for the following cases:
- user attempts to call .backward() inside vmap for any reason
whatsoever
- user attempts to call autograd.grad(outputs, inputs, grad_outputs),
where outputs or inputs is being vmapped over (so they are
BatchedTensors).
The case we do support is calling autograd.grad(outputs, inputs,
grad_outputs) where `grad_outputs` is being vmapped over. This is the
case for batched gradient support (e.g., user passes in a batched
grad_output).
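A hedged sketch of the supported case, computing a batch of vector-Jacobian products by vmapping over `grad_outputs` (assumes `torch.vmap`):
```
import torch

x = torch.randn(3, requires_grad=True)
y = x ** 2  # outputs and inputs are ordinary tensors (not vmapped over)

def vjp(v):
    # Only grad_outputs (v) is a BatchedTensor inside vmap, which is the
    # supported configuration.
    return torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)[0]

batched_v = torch.eye(3)                    # one grad_output per row
jacobian_rows = torch.vmap(vjp)(batched_v)  # row i is d(y_i)/dx
assert torch.allclose(jacobian_rows, torch.diag(2 * x))
```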
Test Plan: - new tests: `pytest test/test_vmap.py -v`
Reviewed By: ezyang
Differential Revision: D23059836
Pulled By: zou3519
fbshipit-source-id: 2fd4e3fd93f558e67e2f0941b18f0d00d8ab439f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42617
While we figure out the random plan, I want to initially disable
support for random operations. This is because there is an ambiguity in
what randomness means. For example,
```
tensor = torch.zeros(B0, 1)
vmap(lambda t: t.normal_())(tensor)
```
In the above example, should tensor[0] and tensor[1] be equal (i.e.,
use the same random seed), or should they be different?
The mechanism for disabling random support is as follows:
- We add a new dispatch key called VmapMode
- Whenever we're inside vmap, we enable VmapMode for all tensors.
This is done via at::VmapMode::increment_nesting and
at::VmapMode::decrement_nesting.
- DispatchKey::VmapMode's fallback kernel is the fallthrough kernel.
- We register kernels that raise errors for all random functions on
DispatchKey::VmapMode. This way, whenever someone calls a random
function on any tensor (not just BatchedTensors) inside of a vmap block,
an error gets thrown.
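Concretely, something like the following is now expected to raise instead of silently picking one interpretation (the error text is paraphrased, not the exact message):
```
import torch

B0 = 4
tensor = torch.zeros(B0, 1)

try:
    torch.vmap(lambda t: t.normal_())(tensor)
except RuntimeError as exc:
    # Random ops are rejected inside vmap until the randomness semantics
    # are settled; the exact wording of the message may differ.
    print("random op inside vmap raised:", exc)
```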
Test Plan: - `pytest test/test_vmap.py -v -k "Operators"`
Reviewed By: ezyang
Differential Revision: D22954840
Pulled By: zou3519
fbshipit-source-id: cb8d71062d4087e10cbf408f74b1a9dff81a226d
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42628
This PR extends the BatchedTensor fallback to support operators with
multiple Tensor returns. If an operator has multiple returns, we stack
shards of each return to create the full outputs.
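For example, an operator with two Tensor returns such as `torch.var_mean` (used in the test below) should now go through the fallback; a hedged sketch assuming `torch.vmap`:
```
import torch

x = torch.randn(5, 7)

# torch.var_mean returns a (var, mean) tuple; under vmap, each return is
# assembled by stacking the per-example results.
var, mean = torch.vmap(torch.var_mean)(x)
assert var.shape == (5,) and mean.shape == (5,)

# Sanity check against the unbatched equivalent.
ref_var, ref_mean = torch.var_mean(x, dim=1)
assert torch.allclose(var, ref_var) and torch.allclose(mean, ref_mean)
```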
Test Plan:
- `pytest test/test_vmap.py -v`. Added a new test for an operator with
multiple returns (torch.var_mean).
Reviewed By: izdeby
Differential Revision: D22957095
Pulled By: zou3519
fbshipit-source-id: 5c0ec3bf51283cc4493b432bcfed1acf5509e662
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42480
These are grouped together because they all return a tuple of multiple
tensors.
This PR implements batching rules for chunk, split, and unbind. It also
updates the testing logic: previously, reference_vmap was not able to
handle multiple outputs; now it does.
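A hedged usage sketch for one of these ops; each element of the returned tuple carries the batch dimension (assumes `torch.vmap`):
```
import torch

x = torch.randn(4, 6)

# Per example, chunk a length-6 vector into two length-3 pieces; under vmap
# the result is a tuple of two tensors, each with the batch dim out front.
first, second = torch.vmap(lambda t: t.chunk(2, dim=0))(x)
assert first.shape == (4, 3) and second.shape == (4, 3)

# Equivalent unbatched computation for comparison.
ref_first, ref_second = x.chunk(2, dim=1)
assert torch.equal(first, ref_first) and torch.equal(second, ref_second)
```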
Test Plan: - `pytest test/test_vmap.py -v -k "Operators"`
Reviewed By: ezyang
Differential Revision: D22905401
Pulled By: zou3519
fbshipit-source-id: 9963c943d035e9035c866be74dbdf7ab1989f8c4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42248
Including:
- torch.diagonal
- torch.t
- torch.select
- Tensor.expand_as
- Tensor slicing.
Please let me know in the future if it would be easier to review these
separately (I put five operators into this PR because each
implementation is relatively simple).
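A couple of hedged usage examples for these ops (assumes `torch.vmap`):
```
import torch

x = torch.randn(5, 3, 3)

# torch.diagonal per example: each (3, 3) slice yields its length-3 diagonal.
diags = torch.vmap(torch.diagonal)(x)
assert diags.shape == (5, 3)

# Tensor slicing per example: drop the first row of every (3, 3) slice.
sliced = torch.vmap(lambda t: t[1:])(x)
assert sliced.shape == (5, 2, 3)
```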
Test Plan:
- new tests in `test/test_vmap.py`.
- I would like to have a more structured/automated way of testing, but
my previous attempts at building one ended up being very complicated.
Reviewed By: ezyang
Differential Revision: D22846273
Pulled By: zou3519
fbshipit-source-id: 8e45ebe11174512110faf1ee0fdc317a25e8b7ac
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41943
If an operator doesn't have a batching rule implemented, then we fall back
to this implementation. The fallback only works on out-of-place operators
that return only tensors with new memory (e.g., no in-place operators,
no view operations).
The fallback effectively takes all of the BatchedTensors in `stack`,
slices them, and runs `op` on all of the corresponding slices to produce slices
of the outputs. The output slices then get `torch.stack`ed to create the
final returns.
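In user-level Python terms, the fallback is morally equivalent to the sketch below (the real implementation works on the dispatcher stack in C++; the single-input, bdim-0 setup here is an illustrative assumption):
```
import torch

def fallback_sketch(op, batched_input):
    # Illustrative only: run op on each slice along the batch dim...
    slices_out = [op(batched_input[i]) for i in range(batched_input.shape[0])]
    # ...then stack the per-slice outputs; this stack is the extra copy
    # discussed below.
    return torch.stack(slices_out)

x = torch.randn(4, 3)
assert torch.allclose(fallback_sketch(torch.abs, x), torch.abs(x))
```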
The performance of the fallback is not very good because it introduces
an extra copy from stacking the sliced outputs. Because of this, we prefer
to write batching rules for operators whenever possible.
In the future, I'd like to disable the fallback kernel for random
functions until we have a better random story for vmap. I will probably
add a blocklist of operators to support that.
Test Plan: - `pytest test/test_vmap.py -v`
Reviewed By: ezyang
Differential Revision: D22764103
Pulled By: zou3519
fbshipit-source-id: b235833f7f27e11fb76a8513357ac3ca286a638b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40717
`in_dims` specifies which dimension of the input tensors should be
vmapped over. One can also specify `None` as an `in_dim` for a particular
input to indicate that we do not map over said input.
We implement `in_dims` by creating a BatchedTensor with BatchDim equal
to said `in_dim`. Most of this PR is error checking. `in_dims` must
satisfy the following:
- `in_dims` can be either an int or a Tuple[Optional[int]]. If it is an
int, we use it as the `in_dim` for every input.
- If `in_dims` is not None at some index `idx`, then the input at index
`idx` MUST be a tensor (vmap can only map over tensors).
JAX supports something more general: its equivalent of `in_dims` can match
the structure of the `inputs` to the function (i.e., it is a nested Python
data structure, mirroring the structure of `inputs`, that specifies where
in `inputs` the Tensors to be mapped are and what their map dims should
be). We don't have that infrastructure yet, so we only support `int` or a
flat tuple for `in_dims`.
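A hedged usage sketch of the flat-tuple form (assumes `torch.vmap`):
```
import torch

batched_x = torch.randn(8, 3)  # mapped over dim 0
weight = torch.randn(3)        # shared across the batch, so its in_dim is None

# (0, None): map dim 0 of the first argument, broadcast the second.
out = torch.vmap(torch.mul, in_dims=(0, None))(batched_x, weight)
assert out.shape == (8, 3)
assert torch.allclose(out, batched_x * weight)
```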
Test Plan: - `pytest test/test_vmap.py -v`
Differential Revision: D22397914
Pulled By: zou3519
fbshipit-source-id: 56d2e14be8b6024e4cde2729eff384da305b4ea3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40576
`out_dims` specifies where in the output tensors the vmapped dimension
should appear. We implement this by simply creating a view with the
batch dimension moved to the desired position.
`out_dims` must be either:
- an int (the same value is used for all outputs)
- a Tuple[int] (the user specifies one out_dim per output).
(See the vmap docstring for what we advertise out_dims to do).
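A hedged example of moving the vmapped dimension with `out_dims` (assumes `torch.vmap`):
```
import torch

x = torch.randn(3, 5)

# With out_dims=1, the batch dimension of size 3 appears at dim 1 of the
# output instead of dim 0.
out = torch.vmap(lambda t: t * 2, out_dims=1)(x)
assert out.shape == (5, 3)
assert torch.allclose(out, (x * 2).t())
```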
I also renamed `TestVmap` to `TestVmapAPI` to make it clearer that we
are testing the API here and not specific operators (which will go into
their own test class).
Test Plan: - `pytest test/test_vmap.py -v`
Differential Revision: D22288086
Pulled By: zou3519
fbshipit-source-id: c8666cb1a0e22c54473d8045477e14c2089167cf
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40518
I overlooked this in the initial vmap frontend API PR. Right now we
want to restrict vmap to taking in functions that only return Tensors.
A function that only returns Tensors can look like one of the following:
```
def fn1(x):
    ...
    return y

def fn2(x):
    ...
    return y, z
```
fn1 returns a Tensor, while fn2 returns a tuple of Tensors. So we add a
check that the function passed to vmap returns either a single Tensor or
a tuple of Tensors.
NB: These checks allow passing vmap a function that returns a one-element
tuple containing a Tensor. That seems OK to me.
Test Plan: - `python test/test_vmap.py -v`
Differential Revision: D22216166
Pulled By: zou3519
fbshipit-source-id: a92215e9c26f6138db6b10ba81ab0c2c2c030929
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40172
This PR introduces the initial vmap frontend API. It has the following
limitations that we can resolve in the future:
- the inputs must be a flat list of tensors
- the outputs must be a flat list of tensors
- in_dims = 0 (so we always vmap over dim 0 of input tensors)
- out_dims = 0 (so the returned tensors have their vmap dim appear at
dim 0)
- Coverage limited to operations that have batching rules implemented
(torch.mul, torch.sum, torch.expand).
There are some other semantic limitations (like not being able to handle
mutation, other than through PyTorch operations that perform mutation)
that will be documented in the future.
I wanted to introduce the API before adding a slow fallback for coverage
so that we can test future batching rules (and coverage) via the Python
API and avoid verbosity in C++-land.
The way vmap works is that `vmap(func)(inputs)` wraps all Tensor inputs
to be batched in BatchedTensors, sends those into func, and then unwraps
the output BatchedTensors. Operations on BatchedTensors perform the batched
operations that the user is asking for. When performing nested vmaps,
each nested vmap adds a batch dimension upon entry and removes a batch
dimension on exit.
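A hedged usage sketch with the initially supported ops, including a nested vmap (assumes the API is exposed as `torch.vmap`):
```
import torch

x = torch.randn(4, 3)
y = torch.randn(4, 3)

# Elementwise-multiply each pair of rows, then sum each result.
out = torch.vmap(lambda a, b: torch.sum(torch.mul(a, b)))(x, y)
assert out.shape == (4,)
assert torch.allclose(out, (x * y).sum(dim=1))

# Nested vmap: the outer vmap maps over dim 0, the inner over dim 0 of each
# slice, so the function body sees scalars.
a = torch.randn(2, 3)
b = torch.randn(2, 3)
nested = torch.vmap(torch.vmap(torch.mul))(a, b)
assert torch.allclose(nested, a * b)
```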
Coming up in the near future:
- Support for non-zero in_dims and out_dims
- docstring for vmap
- slow fallback for operators that do not have a batching rule
implemented.
Test Plan: - `pytest test/test_vmap.py -v`
Differential Revision: D22102076
Pulled By: zou3519
fbshipit-source-id: b119f0a8a3a3b1717c92dbbd180dfb1618295563