Commit Graph

57 Commits

Author SHA1 Message Date
Arindam Roy
556fc8d418 skip test_symeig if MAGMA not detected (#54526)
Summary:
Skip test_symeig properly when MAGMA is not detected.
Added the skipCUDAIfNoMagma decorator.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/54526

Reviewed By: malfet

Differential Revision: D27293640

Pulled By: heitorschueroff

fbshipit-source-id: 245f86540af0e37c8795e80dc003e1ca4c08cd5b
2021-03-24 13:55:36 -07:00
iramazanli
d7b5a6faaa Revert "Revert D26733731: [pytorch][PR] Skip dispatch for `is_floatin… (#53242)
Summary:
…g_point`"

This reverts commit fbf2883d35.

Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53242

Reviewed By: mrshenli

Differential Revision: D26896105

Pulled By: iramazanli

fbshipit-source-id: 279a6f6d4fbb7949a7ed65df848db71a9b8d44e2
2021-03-11 09:46:25 -08:00
Kyle Chen
bf5e5bf901 [ROCm] Enable test in test_linalg.py, test_optim.py and test_vmap.py … (#52818)
Summary:
Enable tests in test_linalg.py, test_optim.py, and test_vmap.py for ROCm because they are now passing.

Signed-off-by: Kyle Chen <kylechen@amd.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/52818

Reviewed By: H-Huang

Differential Revision: D26694091

Pulled By: mruberry

fbshipit-source-id: 285d17aa7f271f4d94b5fa9d9f6620de8a70847b
2021-03-04 02:29:45 -08:00
Richard Zou
1379842f4a Add private mechanism to toggle vmap fallback warnings (#51218)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51218

Fixes #51144.

Context
=======

Users have complained about warning spam from batched gradient
computation. This warning spam happens because warnings in C++ don't
correctly get turned into Python warnings when those warnings arise from
the autograd engine.

To work around that, this PR adds a mechanism to toggle vmap warnings.
By default, the vmap fallback will not warn when it is invoked. However,
by using `torch._C._debug_only_display_vmap_fallback_warnings(enabled)`,
one can toggle the existence of vmap fallback warnings.
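
A minimal sketch of toggling the flag (assuming the prototype `torch.vmap` API):
```py
import torch

# Enable the debug-only fallback warnings, run a vmapped computation, then
# disable them again. Ops that hit the slow fallback would warn in between.
torch._C._debug_only_display_vmap_fallback_warnings(True)
x = torch.randn(3, 5)
torch.vmap(torch.Tensor.sum)(x)
torch._C._debug_only_display_vmap_fallback_warnings(False)
```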

This API is meant to be a private, debug-only API. The goal is to be
able to non-intrusively collect feedback from users to improve
performance on their workloads.

What this PR does
=================

This PR adds an option to toggle vmap warnings. The mechanism is
toggling a bool in ATen's global context.

There are some other minor changes:
- This PR adds a more detailed explanation of performance cliffs to the
autograd.functional.{jacobian, hessian} documentation
- A lot of the vmap tests in `test_vmap.py` rely on the fallback warning
to test the presence of the fallback. In test_vmap, I added a context
manager to toggle on the fallback warning while testing.

Alternatives
============

I listed a number of alternatives in #51144. My favorite one is having a new
"performance warnings mode" (this is currently a WIP by some folks on
the team). This PR is to mitigate the problem of warning spam before
a "performance warnings mode" gets shipped into PyTorch

Concerns
========

I am concerned that we are advertising a private API
(`torch._C._debug_only_display_vmap_fallback_warnings(enabled)`) in the
PyTorch documentation. However, I hope the naming makes it clear to
users that they should not rely on this API (and I don't think they have
any reason to rely on the API).

Test Plan
=========

Added tests in `test_vmap.py` to check:
- by default, the fallback does not warn
- we can toggle whether the fallback warns or not

Test Plan: Imported from OSS

Reviewed By: pbelevich, anjali411

Differential Revision: D26126419

Pulled By: zou3519

fbshipit-source-id: 95a97f9b40dc7334f6335a112fcdc85dc03dcc73
2021-01-28 13:05:00 -08:00
Xiong Wei
5cdc32bf1c [vmap] Add batching rules for comparisons ops (#50364)
Summary:
Related to https://github.com/pytorch/pytorch/issues/49562

This PR adds batching rules for the below comparison ops.
- torch.eq
- torch.gt
- torch.ge
- torch.le
- torch.lt
- torch.ne
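
A minimal sketch of what these rules enable (assuming the prototype `torch.vmap` API):
```py
import torch

# Per-example comparison under vmap; for matching batch dims this agrees
# with calling the op directly on the full tensors.
x = torch.randn(8, 3)
y = torch.randn(8, 3)
out = torch.vmap(torch.eq)(x, y)
assert torch.equal(out, torch.eq(x, y))
```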

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50364

Reviewed By: anjali411

Differential Revision: D25885359

Pulled By: zou3519

fbshipit-source-id: 58874f24f8d525d8fac9062186b1c9970618ff55
2021-01-12 13:00:56 -08:00
Kaiwen Wang
483670ff0f [pytorch] add threshold_backward batching for vmap (#49881)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49881

title

Test Plan: pytest test/test_vmap.py -v -k "BatchedGrad"

Reviewed By: zou3519

Differential Revision: D25711289

fbshipit-source-id: f1856193249fda70da41e36e15bc26ea7966b510
2021-01-04 12:24:05 -08:00
Erjia Guan
da790eca69 Add trace batching forward/backward rule (#49979)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49979

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D25734379

Pulled By: ejguan

fbshipit-source-id: 8f9346afaf324e7ab17bafd6ecc97eed8442fd38
2021-01-04 12:04:55 -08:00
Richard Zou
2ec3e803eb Update accumulate_grad to support vmap (#49119)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49119

I don't know how the accumulate_grad code gets hit via calling
autograd.grad, so I went through all places in accumulate_grad
that are definitely impossible to vmap through and changed them.

To support this:
- I added vmap support for Tensor::strides(). It returns the strides
that correspond to the public dimensions of the tensor (not the ones
being vmapped over).
- Changed an instance of empty_strided to new_empty_strided.
- Replaced an in-place operation in accumulate_grad.h

Test Plan:
- added a test for calling strides() inside of vmap
- added tests that exercise all of the accumulate_grad code path.
NB: I don't know why these tests exercise the code paths, but I've
verified that they do via gdb.

Suggestions for some saner test cases are very welcome.

Reviewed By: izdeby

Differential Revision: D25563543

Pulled By: zou3519

fbshipit-source-id: 05ac6c549ebd447416e6a07c263a16c90b2ef510
2020-12-16 11:30:16 -08:00
Xiong Wei
909a9060e9 [vmap] implement batching rule for fill_ and zero_ (#48516)
Summary:
Fix https://github.com/pytorch/pytorch/issues/47755

- This PR implements batching rules for in-place operators `fill_` and `zero_`.
- Testcases are added to the `test/test_vmap.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48516

Reviewed By: H-Huang

Differential Revision: D25431557

Pulled By: zou3519

fbshipit-source-id: 437b0534dc0b818fbe05f7fcfcb649aa677483dc
2020-12-10 10:59:05 -08:00
Xiong Wei
8f8738ce5c [vmap] implement batching rules for clamp, clamp_min and clamp_max (#48449)
Summary:
Fix https://github.com/pytorch/pytorch/issues/47754

- This PR implements batching rules for `clamp`, `clamp_min` and `clamp_max` operators.
- Testcases are added to `test/test_vmap.py`.
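
A minimal usage sketch (assuming the prototype `torch.vmap` API):
```py
import torch

# Per-example clamp under vmap; equivalent here to clamping the whole batch.
x = torch.randn(5, 4)
out = torch.vmap(lambda t: torch.clamp(t, min=-0.5, max=0.5))(x)
assert torch.equal(out, torch.clamp(x, min=-0.5, max=0.5))
```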

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48449

Reviewed By: ejguan

Differential Revision: D25219360

Pulled By: zou3519

fbshipit-source-id: 0b7e1b00f5553b4578f15a6cc440640e506b4918
2020-11-30 14:22:43 -08:00
Richard Zou
370310bedb batched grad for binary_cross_entropy, symeig (#48057)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48057

This PR fixes batched grad computation for:
- binary_cross_entropy (i.e., vmap through binary_cross_entropy_double_backward)
- symeig (i.e. vmap through symeig_backward)

It was previously impossible to vmap through those functions because
they use in-place operations in a vmap-incompatible way.

See note at
233192be73/aten/src/ATen/BatchedFallback.cpp (L117-L122)
for what it means for an in-place operation to be vmap-incompatible.

This PR adds a check: if the in-place operations in e.g. symeig are
vmap-incompatible and we are inside of a vmap, then we do the
out-of-place variant of the operation. Ditto for binary_cross_entropy.

This is to avoid code duplication: the alternative would be to register
the backward formula as an operator and change just those lines to be
out-of-place!

This PR also adds some general guidelines for what to do if an in-place
operation is vmap-incompatible.

General guidelines
------------------

If an in-place operation used in a backward formula is vmap-incompatible,
then as developers we have the following options:

- If the in-place operation directly followed the creation of a tensor with
  a factory function like at::zeros(...), we should replace the factory with a
  corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call
  propagates the batch dims to the resulting tensor.
  For example:
    Before: at::zeros(input.sizes(), grad.options()).copy_(grad)
    After:  grad.new_zeros(input.sizes()).copy_(grad)

- If the in-place operation followed some sequence of operations, and we
  want to be able to vmap over the backward formula as-is (this is
  usually the case for simple (<15 LOC) backward formulas), then use
  inplace_is_vmap_compatible to guard the operation. For example:
            c = a * b
    Before: c.mul_(grad)
    After:  c = inplace_is_vmap_compatible(c, grad) ? c.mul_(grad) : c * grad

- If we don't want to vmap directly over the backward formula (e.g., if the
  backward formula is too complicated or has a lot of vmap-incompatible
  operations), then register the backward formula as an operator and eventually
  write a batching rule for it.

Test Plan
---------
New tests

Test Plan: Imported from OSS

Reviewed By: zhangguanheng66

Differential Revision: D25069525

Pulled By: zou3519

fbshipit-source-id: e0dfeb5a812f35b7579fc6ecf7252bf31ce0d790
2020-11-19 07:59:02 -08:00
Richard Zou
05a76ed705 Batching rule for torch.squeeze(tensor) (#47632)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47632

This one is fun because we have to be careful not to squeeze out any of
the batch dims (it is the dims of the per-example tensor that are being squeezed).
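
A minimal sketch of the resulting behavior (assuming the prototype `torch.vmap` API):
```py
import torch

# The batch dim (size 8) is never squeezed out; only per-example size-1 dims are.
x = torch.randn(8, 1, 3, 1)
out = torch.vmap(torch.squeeze)(x)
print(out.shape)  # torch.Size([8, 3])
```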

Test Plan: - new tests

Reviewed By: anjali411

Differential Revision: D24859022

Pulled By: zou3519

fbshipit-source-id: 8adbd80963081efb683f62ea074a286a10da288f
2020-11-11 14:08:39 -08:00
Richard Zou
df887936a4 Fix transpose batching rule (#47628)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47628

Pytorch has a special case where scalar_tensor.transpose(0, 0) works and
returns the scalar tensor. If the following happens:
```py
>>> x = torch.randn(B0)  # the per-examples are all scalars
>>> vmap(lambda x: x.transpose(0, 0))(x)
```
then we replicate this behavior

Test Plan: - new tests

Reviewed By: anjali411

Differential Revision: D24843658

Pulled By: zou3519

fbshipit-source-id: e33834122652473e34a18ca1cecf98e8a3b84bc1
2020-11-11 14:08:37 -08:00
Richard Zou
f6ff6478cf Make kwargs argument optional in _batched_grad_test (#47625)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47625

kwargs is {} most of the time so this PR makes it optional. Note that it
is bad practice for {} to be a default argument; we work around this by
using None as the default and handling it accordingly.
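
An illustrative sketch of the None-as-default pattern (names here are hypothetical, not the actual test helper):
```py
# Avoid a mutable {} default by defaulting to None and substituting inside.
def run_case(fn, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    return fn(*args, **kwargs)
```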

Test Plan
- `pytest test/test_vmap.py -v`

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D24842571

Pulled By: zou3519

fbshipit-source-id: a46b0c6d5240addbe3b231b8268cdc67708fa9e0
2020-11-11 14:08:35 -08:00
Richard Zou
fc24d0656a Tensor.contiguous, Tensor.is_contiguous batch rule (#47621)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47621

Followup to #47365.

is_contiguous on BatchedTensorImpl is implemented as:
- Whenever one creates a BatchedTensorImpl, we cache the strides of the
per-examples, just like how we cache the sizes of the per-examples.
- With the cached strides, we use TensorImpl::refresh_contiguous() to
compute if the tensor is contiguous or not.
- is_contiguous checks the `is_contiguous_` flag that
refresh_contiguous() populates.

Both contiguous and is_contiguous only support torch.contiguous_format.
I'm not sure what the semantics should be for other memory formats; they
are also rank dependent (e.g., channels_last tensor must have 4
dimensions) which makes this a bit tricky.

Test Plan: - new tests

Reviewed By: Chillee, anjali411

Differential Revision: D24840975

Pulled By: zou3519

fbshipit-source-id: 4d86dbf11e2eec45f3f08300ae3f2d79615bb99d
2020-11-11 14:06:05 -08:00
Richard Zou
57dcb04239 Batched gradient support for view+inplace operations (#47227)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47227

Motivation
----------
We would like to compute batched gradients for view+inplace operations.
This most notably shows up in internal implementation of operations.
For example, many view backward functions (SelectBackward, DiagonalBackward)
are implemented with view+inplace, so to support vectorized hessian
computation for e.g. torch.select and torch.diagonal we would need a
way to handle or work around view+inplace.

Approach
--------
view+inplace creates a CopySlices node and transmute view backward nodes
into an AsStrided node. For example,

```
leaf = torch.randn(4, 5, requires_grad=True)
base = leaf * leaf
view = base[0]
view.cos_()
```

base.grad_fn is CopySlices and view.grad_fn is AsStridedBackward.

To support vmap over CopySlices and AsStridedBackward:
- We use `new_empty_strided` instead of `empty_strided` in CopySlices
so that the batch dims get propagated
- We use `new_zeros` inside AsStridedBackward so that the batch dims get
propagated.

Test Plan
---------
- New tests. When we get closer to having most operations support batched
grad computation via vmap, I'd like to add it as an option to gradcheck
and turn it on for our tests.

Test Plan: Imported from OSS

Reviewed By: kwanmacher, glaringlee

Differential Revision: D24741687

Pulled By: zou3519

fbshipit-source-id: 8210064f782a0a7a193752029a4340e505ffb5d8
2020-11-10 07:38:02 -08:00
Richard Zou
ead86b2419 Add batching rule for torch.clone(tensor, torch.contiguous_format) (#47365)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47365

I wanted to avoid defining vmap behavior over contiguous_format for as
long as possible. This is potentially ambiguous; consider the following:
```
>>> x = torch.randn(3, B0, 5)
>>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1,
out_dims=1)(x)
>>> y[:,0].is_contiguous()  # ??
```
There are two possible ways to interpret this operation (if we choose to
allow it to succeed):
1. Each per-sample becomes contiguous, so y[:,0] is contiguous.
2. The output of vmap is contiguous (so y is contiguous, but y[:,0] is
not)

(1) makes more sense because vmap operates on a per-sample level.
This makes sense when combined with the vmap fallback:
- there are places in the codebase where we perform .contiguous() and
then pass the result to an operator `op` that only accepts contiguous
inputs.
- If we vmap over such code and don't have a batching rule implemented for
`op`, then we want the per-samples to be contiguous so that
when `op` goes through the vmap fallback, it receives contiguous
per-samples.

(1) is the approach we've selected for this PR.

Motivation
----------
To vmap over CopySlices, we have to vmap over a clone(contiguous_format)
call:
e4bc785dd5/torch/csrc/autograd/functions/tensor.cpp (L93)

Alternatives
------------
- Implementing (2) is difficult in the current design because vmap is
allowed to move batch dimensions to the front of the tensor. We would
need some global information about the in_dims and out_dims passed to
vmap.
- We could also error out if someone calls clone(contiguous_format) and
the batch dims are not at the front. This would resolve the ambiguity at
the cost of limiting what vmap can do.

Future Work
-----------
- Add to a "vmap gotchas" page the behavior of contiguous_format.
- Implement is_contiguous, Tensor.contiguous() with the same semantics.
Those currently error out.

Test Plan
---------
- new tests

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D24741683

Pulled By: zou3519

fbshipit-source-id: 3ef5ded1b646855f41d39dcefe81129176de8a70
2020-11-09 11:36:48 -08:00
Richard Zou
7bc8fdb6d7 as_strided batching rule (#47364)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47364

This PR adds a batching rule for as_strided. `as_strided` is a really weird
operation and I hope that users don't use it very much.

Motivation
----------
The motivation for adding a batching rule for as_strided is for
batched gradient computation.

AsStridedBackward appears in PyTorch when handling view+in-place
operations and calls `as_strided`. AsStridedBackward calls as_strided on
a fresh tensor with storage_offset equal to 0. We would like to be able
to vmap through the backward graph of view+in-place operations
for batched gradient computation, especially because internally we have
a number of functions that are implemented as a view+in-place.

Alternatives
------------
If we think that as_strided is too crazy to have a batching rule, we
could either:
- have a flag that controls the autograd view+in-place
behavior
- require that the input tensor's storage offset must be equal to 0
to make it easier to reason about.

I think the batching rule makes sense, so I didn't pursue the
alternatives.

The batching rule
-----------------
```
y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
```
The result of the above should be "equivalent" to:
- Assume that each x has storage offset equal to xs.storage_offset()
(call that S).
- Calling as_strided with (sizes, strides, offset + x[i].storage_offset() - S) on each x.

More concretely,
this returns a view on `xs`, such that each y[i] has:
- sizes: `sizes`
- strides: `strides`
- storage_offset: offset + i * x.stride(batch_dim)

Why the behavior can be weird
-----------------------------
The behavior of the batching rule may be different from actually running
as_strided in a for-loop because `as_strided` takes in `offset` as an
"absolute offset". As an example, consider

```
>>> x = torch.tensor([0., 1., 2., 3., 4.])
>>> z = [x[i].as_strided([1], [1], 0) for i in range(5)]
```
Each z[i] is actually the same view on x (z[i] == torch.tensor([0.]))!
However, we consider the above list comprehension to be a user error:
a user should have written the following if they wanted to use as_strided
in a per-sample way:
```
>>> z = [x[i].as_strided([1], [1], 0 + x[i].storage_offset()) for i in range(5)]
```

Test Plan
---------
- Added some tests that compare vmap+as_strided to vmap+(the equivalent operator)

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D24741685

Pulled By: zou3519

fbshipit-source-id: c1429caff43bfa33661a80bffc0daf2c0eea5564
2020-11-09 11:36:44 -08:00
Richard Zou
b80da89891 Batching rule for Tensor.new_empty_strided (#47226)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47226

The batching rule is a little weird because it's not immediately obvious
what the strides of the result should be. If
tensor.new_empty_strided(size, stride) is called inside vmap and
`tensor` is being vmapped over, the result is a physical tensor with:
- size `[batch_shape] + size`
- strides `[S0, S1, ..., Sn] + stride` such that the
S0...Sn are part of a contiguous subspace and Sn is equal to the size of
the storage of `torch.empty_strided(size, stride)`.

I refactored some of the logic that computes the storage size for
`torch.empty_strided(size, stride)` into a helper function
`native::storage_size_for` and use it in the batching rule.
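
A minimal sketch of the resulting shapes (assuming the prototype `torch.vmap` API):
```py
import torch

# Each per-example call produces a logical (2, 2) tensor, so the stacked
# result is (4, 2, 2); the values are uninitialized.
x = torch.randn(4, 3)
out = torch.vmap(lambda t: t.new_empty_strided((2, 2), (2, 1)))(x)
print(out.shape)  # torch.Size([4, 2, 2])
```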

Test Plan: - New tests in test/test_vmap.py

Reviewed By: ejguan

Differential Revision: D24741690

Pulled By: zou3519

fbshipit-source-id: f09b5578e923470d456d50348d86687a03b598d2
2020-11-09 08:31:04 -08:00
Richard Zou
9c8f40516f Batched grad for advanced indexing (index) (#47223)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47223

This PR enables batched gradient computation for advanced indexing.
Previously, the backward formula was writing parts of the grad tensor
in-place to zeros_like(self). Since grad is a BatchedTensor and self is
not a BatchedTensor, this is not possible.

To solve the problem, we instead create a new tensor with
`grad.new_zeros` and then write to that in-place. This new tensor will
have the same batchedness as the `grad` tensor.

To prevent regressions (the autograd codegen special cases zeros_like
to avoid saving the `self` tensor for backward), we teach the autograd
codegen how to save `self.options()`.

Test Plan:
- new tests
- run old indexing tests

Reviewed By: ejguan

Differential Revision: D24741684

Pulled By: zou3519

fbshipit-source-id: e267999dc079f4fe58c3f0bdf5c263f1879dca92
2020-11-05 18:25:33 -08:00
Richard Zou
e40a563050 Fix sum batching rule, add simple clone batching rule (#47189)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47189

PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
and instead returns a new copy of the original scalar_tensor. If we
end up vmapping over per-example scalar tensors, e.g.,
```
>>> x = torch.randn(B0)  # the per-examples are all scalars
>>> vmap(partial(torch.sum, dim=0))(x)
```
then we should replicate the behavior of sum(scalar_tensor, dim=0) by
returning a clone of the input tensor.

This PR also adds a batching rule for clone(Tensor, MemoryFormat). The
batching rule:
- unwraps the BatchedTensor, calls clone(), and rewraps the
BatchedTensor if MemoryFormat is torch.preserve_format (which is the
default).
- errors out with an NYI for all other memory formats, including
torch.contiguous_format. There are some weird semantics for memory
layouts with vmap that I need to go and figure out. Those are noted in
the comments for `clone_batching_rule`

Test Plan: - new tests

Reviewed By: ejguan

Differential Revision: D24741689

Pulled By: zou3519

fbshipit-source-id: e640344b4e4aa8c0d2dbacc5c49901f4c33c6613
2020-11-05 07:38:43 -08:00
Richard Zou
9a9529aa84 Batching rules for complex view functions (#47188)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47188

Includes batching rules for:
- torch.real, torch.imag, torch.view_as_real, and torch.view_as_complex
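
A minimal usage sketch (assuming the prototype `torch.vmap` API):
```py
import torch

# Per-example view_as_real under vmap; the trailing real/imag dim is appended
# per example.
z = torch.randn(4, 3, dtype=torch.cfloat)
out = torch.vmap(torch.view_as_real)(z)
print(out.shape)  # torch.Size([4, 3, 2])
```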

Test Plan: - new tests

Reviewed By: ejguan

Differential Revision: D24741686

Pulled By: zou3519

fbshipit-source-id: c143bab9bb5ebbcd8529e12af7c117cbebd4447e
2020-11-05 07:37:15 -08:00
Richard Zou
02dc52f25b vmap fallback: gracefully error out when vmap over dim of size 0 (#46846)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46846

Previously, this would crash with a floating point error. If the user vmaps
over a dimension of size 0, ideally we would return a tensor with a
batch dim of size 0 and the correct output shape. However, this isn't
possible without a shape-checking API. This PR changes the vmap fallback
to error out gracefully if it sees vmap occurring over a dimension of
size 0.

If we want to support vmapping over dimension of size 0 for a specific
op, then the guidance is to implement a batching rule for that op that
handles 0-sized dims.

Test Plan: - new test

Reviewed By: ezyang

Differential Revision: D24539315

Pulled By: zou3519

fbshipit-source-id: a19c049b46512d77c084cfee145720de8971f658
2020-10-26 15:32:22 -07:00
Richard Zou
74d81080a0 Use new_zeros in evenly_distribute_backward (#46674)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46674

Summary
-------

This adds batched gradient support (i.e., vmap through the gradient
formulas) for Tensor.max(), Tensor.min(), Tensor.median()
that have evenly_distribute_backward as their backward formula.

Previously, the plan was to register incompatible gradient formulas as
backward operators (see #44052). However, it turns out that we can just use
`new_zeros` to get around some incompatible gradient formulas (see next
section for discussion).

Context: the vmap+inplace problem
---------------------------------

A lot of backwards functions are incompatible with BatchedTensor due to
using in-place operations. Sometimes we can allow the in-place
operations, but other times we can't. For example, consider select_backward:

```
Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes,
                       int64_t dim, int64_t index) {
  auto grad_input = at::zeros(input_sizes, grad.options());
  grad_input.select(dim, index).copy_(grad);
  return grad_input;
}
```
and consider the following code:

```
x = torch.randn(5, requires_grad=True)
def select_grad(v):
  return torch.autograd.grad(x[0], x, v)

vs = torch.randn(B0)
batched_grads = vmap(select_grad)(vs)
```

For the batched gradient use case, grad is a BatchedTensor.
The physical version of grad has size (B0,).
However, select_backward creates a grad_input of shape (5), and
tries to copy grad to a slice of it.

Up until now, the proposal to handle this has been to register these
backward formulas as operators so that vmap doesn’t actually see the
`copy_` calls (see #44052). However, it turns out we can actually just
use `new_zeros` to construct a new Tensor that has the same
"batched-ness" as grad:
```
auto grad_input = grad.new_zeros(input_sizes);
grad_input.select(dim, index).copy_(grad);
```
We should use this for simple backward functions. For more complicated
backward functions where this solution doesn't work, we should register
those as operators.

Alternatives
------------
Option 2: Register `evenly_distribute_backward` as an operator and have the
vmap fallback run it in a loop.
- This requires more LOC changes.
- Furthermore, we'd have to write an efficient batching rule for
`evenly_distribute_backward` in the future.
- If we use `new_zeros` instead, we don't need to write an efficient
batching rule for `evenly_distribute_backward` as long as the
constituents of `evenly_distribute_backward` have efficient batching rules.

Option 3: Have factory functions perform differently if they are called
inside vmap.
- For example, `at::zeros(3, 5)` could return a Tensor of shape
`(B0, B1, 3, 5)` if we are vmapping over two dimensions with size B0 and B1.
This requires maintaining some global and/or thread-local state about
the size of the dims being vmapped over which can be tricky.

And more...

Future
------
- I will undo some of the work I’ve done in the past to move backward
functions to being operators (#44052, #44408). The simpler backward
functions (like select backward) can just use Tensor.new_zeros.
I apologize for the thrashing.
- Include a NOTE about the vmap+inplace problem somewhere in the
codebase. I don't have a good idea of where to put it at the moment.

Test Plan
---------
- New tests

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D24456781

Pulled By: zou3519

fbshipit-source-id: 9c6c8ee2cb1a4e25afd779bdf0bdf5ab76b9bc20
2020-10-23 14:29:40 -07:00
Richard Zou
aa828bf084 Support undefined grads in vmap fallback (#46671)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46671

Previously, the vmap fallback would choke whenever it saw an undefined
tensor. For each sample in a batch, the fallback runs an operator
and then stacks together outputs to get the actual output.
Undefined tensors can occur as outputs while computing batched gradients
with vmap.

This PR updates the vmap fallback to handle undefined tensors which can
appear in backward formulas:
- if for each sample in a batch the output was undefined, then the vmap
fallback returns an undefined tensor
- if for each sample in a batch the output is defined, then the vmap
fallback stacks together the defined tensors
- if the output is defined for some samples in a batch but undefined for others, then
we error out.

Test Plan: - new tests

Reviewed By: ezyang

Differential Revision: D24454909

Pulled By: zou3519

fbshipit-source-id: d225382fd17881f23c9833323b68834cfef351f3
2020-10-23 14:26:50 -07:00
Richard Zou
18d80501a6 Batching rules for: new_zeros, new_empty (#46606)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46606

Note that new_empty uses `m.impl_UNBOXED` because the operator doesn't
go through the c10 dispatcher due to #43572.

Test Plan: - new tests

Reviewed By: ezyang

Differential Revision: D24428106

Pulled By: zou3519

fbshipit-source-id: 5e10f87a967fb27c9c3065f3d5b577db61aeb20e
2020-10-22 11:40:51 -07:00
Richard Zou
1c8d0d8cc9 Allow vmap to accept nested python data structures as inputs (#46289)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46289

Previously, vmap had the restriction that any Tensors in the inputs must
not be a part of a nested python collection. This PR relaxes that
restriction. We can also do the same thing for vmap outputs, but I'll
leave that for future work

The mechanism behind vmap is to convert any Tensor inputs (that have
been specified via in_dims) into BatchedTensor. Using a pytree
implementation, that logic becomes:
- flatten inputs
- broadcast in_dims to inputs and unflatten it
- use the flat inputs and flat in_dims to construct BatchedTensors
- unflatten the BatchedTensors into the same structure as the original
inputs.
- Send the unflattened BatchedTensors into the desired function.
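
A minimal sketch of the relaxed input handling (assuming the prototype `torch.vmap` API, with the default `in_dims=0` broadcast to every leaf tensor):
```py
import torch

# Tensors nested inside a dict are now accepted as vmap inputs.
x = torch.randn(4, 3)
y = torch.randn(4, 3)

def f(d):
    return d['a'] + d['b']

out = torch.vmap(f)({'a': x, 'b': y})
print(out.shape)  # torch.Size([4, 3])
```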

Performance
-----------
Some benchmarking using
```
import torch
def foo(a, b, c, d):
    return a, b, c, d

x = torch.randn(2, 3)
foo_vmap = torch.vmap(foo)
%timeit foo_vmap(x, x, x, x)
```
shows a slowdown from 15us to 25us on my machine. The 10us overhead is
not a lot, especially since our vmap implementation is a "prototype". We
can reduce the overhead in the future by either moving part of
the pytree implementation into C++ or depending on a library that has a
performant pytree implementation.

Test Plan
---------
- New tests, also updated old tests.

Test Plan: Imported from OSS

Reviewed By: heitorschueroff

Differential Revision: D24392892

Pulled By: zou3519

fbshipit-source-id: 072b21dcc6065ab43cfd341e84a01a5cc8ec3daf
2020-10-20 07:52:17 -07:00
Richard Zou
f96cb9de79 vmap: added fallback for in-place operators (#46191)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46191

This PR adds a fallback for in-place operators to vmap. We define an
in-place operator to be an operator that operators in-place on its first
argument and returns the first argument.

The "iteration over batch" logic is mostly copied from the out-of-place
vmap fallback. I wanted to try to not copy this but the iteration logic
is pretty entangled with the rest of the logic; one alternative was to
use if/else statements inside batchedTensorForLoopFallback but then
there are ~3-4 different sites where we would need that.

When in-place operations are not possible
=========================================
Sometimes, an in-place operation inside of vmap is not possible. For
example, `vmap(Tensor.add_, (None, 0))(torch.rand(3), torch.rand(B0, 3))`
is not possible because the tensor being written to in-place has size
[3] and the other tensor has size [B0, 3].

We detect if this is the case and error out inside the in-place
fallback.
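
A sketch of the supported case (assuming the prototype `torch.vmap` API):
```py
import torch

# Each row of x is updated in place with the matching row of y, writing
# through to the underlying tensor x.
x = torch.randn(4, 3)
y = torch.randn(4, 3)
torch.vmap(torch.Tensor.add_)(x, y)
# The incompatible case described above, e.g.
# torch.vmap(torch.Tensor.add_, in_dims=(None, 0))(torch.rand(3), torch.rand(4, 3)),
# errors out in the in-place fallback instead.
```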

Test Plan
=========
Added some new tests to `test_vmap.py`.

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D24335240

Pulled By: zou3519

fbshipit-source-id: 1f60346059040dc226f0aeb80a64d9458208fd3e
2020-10-15 15:21:25 -07:00
Kurt Mohler
ef4817fe5a Add tensor_split function, based on numpy.array_split (#45168)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/9382
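
A minimal sketch of the numpy.array_split-style semantics:
```py
import torch

# 7 elements into 3 parts gives sizes 3, 2, 2 (matching numpy.array_split).
x = torch.arange(7)
print(torch.tensor_split(x, 3))
# (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
```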

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45168

Reviewed By: ngimel

Differential Revision: D24166164

Pulled By: mruberry

fbshipit-source-id: 795459821e52885bc99623a01a2abec060995ce6
2020-10-07 23:14:48 -07:00
Richard Zou
1cd5ba49c6 Add batching rule for "is_complex", "conj" (#44649)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44649

To unblock #43208, which adds "is_complex" checks to backward formulas
that are being tested for batched gradient support with vmap.

Test Plan: - `pytest test/test_vmap.py -v`

Reviewed By: anjali411

Differential Revision: D23685356

Pulled By: zou3519

fbshipit-source-id: 29e41a9296336f6d1008e3040cade4c643bf5ebf
2020-09-16 12:19:46 -07:00
Richard Zou
07cba8b1fc Run vmap tests in CI (#44656)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44656

All this time, test_vmap wasn't running in the CI. Fortunately all the
tests pass locally for me. h/t to anjali411 for pointing this out.

Test Plan: - Wait for CI

Reviewed By: anjali411

Differential Revision: D23689355

Pulled By: zou3519

fbshipit-source-id: 543c3e6aed0af77bfd6ea7a7549337f8230e3d32
2020-09-15 10:59:00 -07:00
Richard Zou
e2bb34e860 Batched grad support for: slice, select, diagonal (#44505)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44505

Added batching rules for slice_backward, select_backward, and
diagonal_backward.

Test Plan: - new tests: `pytest test/test_vmap.py -v -k "BatchedGrad"`

Reviewed By: agolynski, anjali411

Differential Revision: D23650409

Pulled By: zou3519

fbshipit-source-id: e317609d068c88ee7bc07fab88b2b3acb8fad7e1
2020-09-11 14:59:58 -07:00
Richard Zou
7632484000 Add some batched gradient tests (#44494)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44494

These tests check (most) operations that are useful for bayesian logistic
regression (BLR) models. Said operators are basically those found in the
log_prob functions of Distributions objects. This PR is not a general,
structured solution for testing batched gradients (see "Alternative
solution" for that), but I wanted to test a small subset of operations
to confirm that the BLR use case works.

There will be follow-up PRs implementing support for some missing
operations for the BLR use case.

Alternative solution
=====================

Ideally, and in the future, I want to autogenerate tests from
common_method_invocations and delete all of the manual tests
introduced by this PR. However, if we were to do this now,
we would need to store the following additional metadata somewhere:
- operator name, supports_batched_grad, allow_vmap_fallback_usage

We could store that metadata as a separate table from
common_method_invocations, or add two columns to
common_method_invocations. Either way that seems like a lot of work and
the situation will get better once vmap supports batched gradients for
all operators (on the fallback path).

I am neutral between performing the alternative approach now v.s. just
manually writing out some tests for these operations, so I picked the
easier approach. Please let me know if you think it would be better to
pursue the alternative approach now.

Test Plan: - `pytest test/test_vmap.py -v -k "BatchedGrad"`

Reviewed By: anjali411

Differential Revision: D23650408

Pulled By: zou3519

fbshipit-source-id: 2f26c7ad4655318a020bdaab5c767cd3956ea5eb
2020-09-11 14:59:54 -07:00
Richard Zou
b6e2b1eac7 BatchedFallback: stop emitting the entire schema in the fallback warning (#44051)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44051

Instead, just emit the operator name. The entire schema is pretty wordy
and doesn't add any additional information.

Test Plan: - modified test: `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23481184

Pulled By: zou3519

fbshipit-source-id: 9fbda61fc63565507b04c8b87e0e326a2036effa
2020-09-03 08:33:51 -07:00
Richard Zou
9b98bcecfa torch.cat and torch.stack batching rules (#43798)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43798

These are relatively straightforward.
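
A minimal usage sketch (assuming the prototype `torch.vmap` API):
```py
import torch

# Per-example torch.stack under vmap; each pair of rows is stacked.
x = torch.randn(4, 3)
y = torch.randn(4, 3)
out = torch.vmap(lambda a, b: torch.stack([a, b]))(x, y)
print(out.shape)  # torch.Size([4, 2, 3])
```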

Test Plan: - `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23405000

Pulled By: zou3519

fbshipit-source-id: 65c78da3dee43652636bdb0a65b636fca69e765d
2020-09-01 08:12:46 -07:00
Richard Zou
dbc4218f11 Batching rules for: torch.bmm, torch.dot (#43781)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43781

Test Plan: - `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23400843

Pulled By: zou3519

fbshipit-source-id: a901bba6dc2d8435d314cb4dac85bbd5cd4ee2a5
2020-09-01 08:12:43 -07:00
Richard Zou
fa12e225d3 Batching rule for torch.mv (#43780)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43780

The general strategy is:
- unsqueeze the physical inputs enough
- pass the unsqueezed physical inputs to at::matmul
- squeeze any extra dimensions
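
A minimal sketch of the resulting behavior (assuming the prototype `torch.vmap` API):
```py
import torch

# Batched matrix-vector products via vmap over torch.mv.
mats = torch.randn(5, 3, 4)
vecs = torch.randn(5, 4)
out = torch.vmap(torch.mv)(mats, vecs)
assert torch.allclose(out, torch.bmm(mats, vecs.unsqueeze(-1)).squeeze(-1))
```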

Test Plan: - `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23400842

Pulled By: zou3519

fbshipit-source-id: c550eeb935747c08e3b083609ed307a4374b9096
2020-09-01 08:12:41 -07:00
Richard Zou
2789a4023b TestVmapOperators: add structured tests that batching rules get invoked (#43731)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43731

After this PR, for each test in TestVmapOperators, TestVmapOperators
tests that the test never invokes the slow vmap fallback path. The
rationale behind this change is that TestVmapOperators is used for
testing batching rules and we want confidence that the batching rules
actually get invoked.

We set this up using a similar mechanism to the CUDA memory leak check:
(bff741a849/torch/testing/_internal/common_utils.py (L506-L511))

This PR also implements the batching rule for `to.dtype_layout`; the new
testing caught that we were testing vmap on `to.dtype_layout` but it
didn't actually have a batching rule implemented!

Test Plan: - New tests in `pytest test/test_vmap.py -v` that test the mechanism.

Reviewed By: ezyang

Differential Revision: D23380729

Pulled By: zou3519

fbshipit-source-id: 6a4b97a7fa7b4e1c5be6ad80d6761e0d5b97bb8c
2020-09-01 08:11:35 -07:00
Richard Zou
1cdb9d2ab5 Test runner for batched gradient computation with vmap (#43664)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43664

This PR implements the test runner for batched gradient computation with
vmap. It also implements the batching rule for sigmoid_backward and
tests that one can compute batched gradients with sigmoid (and batched
2nd gradients).
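
A minimal sketch of the batched-gradient pattern exercised here (assuming the prototype `torch.vmap` API):
```py
import torch

# vmap autograd.grad over a batch of grad_outputs for sigmoid.
x = torch.randn(3, requires_grad=True)
y = torch.sigmoid(x)

def vjp(v):
    return torch.autograd.grad(y, x, v, retain_graph=True)[0]

vs = torch.randn(5, 3)               # a batch of grad_outputs
batched_grads = torch.vmap(vjp)(vs)  # shape (5, 3)
```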

Test Plan: - New tests: `python test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23358555

Pulled By: zou3519

fbshipit-source-id: 7bb05b845a41b638b7cca45a5eff1fbfb542a51f
2020-08-31 08:21:41 -07:00
Richard Zou
b3f8834033 Batching rule for torch.pow, torch.result_type (#43515)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43515

This PR adds a batching rule for torch.pow. This required adding a
batching rule for torch.result_type.
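
A minimal usage sketch (assuming the prototype `torch.vmap` API):
```py
import torch

# Per-example torch.pow under vmap with a scalar exponent.
x = torch.randn(6, 3)
out = torch.vmap(lambda t: torch.pow(t, 2))(x)
assert torch.allclose(out, x ** 2)
```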

Test Plan: - added new tests: `pytest test/test_vmap.py -v`

Reviewed By: cpuhrsch

Differential Revision: D23302737

Pulled By: zou3519

fbshipit-source-id: 2cade358750f6cc3abf45f81f2394900600927cc
2020-08-25 17:55:53 -07:00
Richard Zou
c972e6232a Implement batching rules for basic arithmetic ops (#43362)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43362

Batching rules implemented for: addition subtraction division
multiplication.

I refactored the original `mul_batching_rule` into a templated function
so that one can insert arbitrary binary operations into it.

add, sub, rsub, mul, and div all work the same way. However, other
binary operations work slightly differently (I'm still figuring out the
differences and why they're different) so those may need a different
implementation.

Test Plan: - "pytest test/test_vmap.py -v": new tests

Reviewed By: ezyang

Differential Revision: D23252317

Pulled By: zou3519

fbshipit-source-id: 6d36cd837a006a2fd31474469323463c1bd797fc
2020-08-24 08:43:36 -07:00
Richard Zou
c66ca7a48d vmap: Fix bug with x * 0.1 (#43218)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43218

Previously, `vmap(lambda x: x * 0.1)(torch.ones(3))` would return a
float64 tensor(!!). This is because there is a subtle bug in the
batching rule: the batching rule receives:
- A batched tensor for x
- a scalar tensor: tensor(0.1, dtype=torch.float64).
The batching rule decides to expand the scalar tensor to be the same
size as x and then multiplies the two tensors, promoting the output to
be a float64 tensor. However, this isn't correct: we should treat the
scalar tensor as a scalar for type-promotion purposes. When adding a FloatTensor to a
Double scalar tensor, we don't usually promote the type.

Another example of a bug this PR fixes is the following:
`vmap(torch.mul)(torch.ones(3), torch.ones(3, dtype=torch.float64))`
Multiplying a scalar float tensor with a scalar double tensor produces a
float tensor, but the above produced a float64 before this PR due to
mistakenly type-promoting the tensors.
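
A minimal sketch of the fixed behavior (assuming the prototype `torch.vmap` API):
```py
import torch

# A float32 input stays float32 when multiplied by a Python float inside vmap.
x = torch.ones(3)
out = torch.vmap(lambda t: t * 0.1)(x)
print(out.dtype)  # torch.float32
```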

Test Plan:
- new test: `pytest test/test_vmap.py -v`
- I refactored some tests a bit.

Reviewed By: cpuhrsch

Differential Revision: D23195418

Pulled By: zou3519

fbshipit-source-id: 33b7da841e55b47352405839f1f9445c4e0bc721
2020-08-20 13:44:31 -07:00
Richard Zou
7d10298067 Implement Tensor.to batching rule (#43206)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43206

The batching rule is the same as the unary pointwise batching rules:
given a BatchedTensor, we unwrap it, call Tensor.to, and then re-wrap
it.
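
A minimal usage sketch (assuming the prototype `torch.vmap` API):
```py
import torch

# Per-example Tensor.to under vmap; each per-example tensor is converted.
x = torch.randn(4, 3)
out = torch.vmap(lambda t: t.to(torch.float64))(x)
print(out.dtype)  # torch.float64
```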

Test Plan: - `pytest test/test_vmap.py -v -k`

Reviewed By: ezyang

Differential Revision: D23189053

Pulled By: zou3519

fbshipit-source-id: 51b4e41b1cd34bd082082ec4fff3c643002edbaf
2020-08-19 10:54:26 -07:00
Richard Zou
37252e8f00 Implement batching rules for some unary ops (#43059)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43059

This PR implements batching rules for some unary ops. In particular, it
implements the batching rules for the unary ops that take a single
tensor as input (and nothing else).

The batching rule for a unary op is:
(1) grab the physical tensor straight out of the BatchedTensor
(2) call the unary op
(3) rewrap the physical tensor in a BatchedTensor
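
A minimal sketch of the resulting behavior (assuming the prototype `torch.vmap` API):
```py
import torch

# A unary op under vmap; per-example application matches applying the
# elementwise op to the whole batch.
x = torch.randn(5, 3)
out = torch.vmap(torch.sin)(x)
assert torch.allclose(out, torch.sin(x))
```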

Test Plan: - new tests `pytest test/test_vmap.py -v -k "Operators"`

Reviewed By: ezyang

Differential Revision: D23132277

Pulled By: zou3519

fbshipit-source-id: 24b9d7535338207531d767155cdefd2c373ada77
2020-08-17 13:38:10 -07:00
Richard Zou
768c2a8c25 vmap: fixed to work with functools.partial (#43028)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43028

There was a bug where we always tried to grab the `__name__` attribute of
the function passed in by the user. Not all Callables have the
`__name__` attribute, an example being a Callable produced by
functools.partial.

This PR modifies the error-checking code to use `repr` if `__name__` is
not available. Furthermore, it moves the "get the name of this function"
functionality to the actual error sites as an optimization so we don't
spend time trying to compute `__repr__` for the Callable if there is no
error.
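
A minimal sketch of the fixed behavior (assuming the prototype `torch.vmap` API):
```py
import functools
import torch

# A Callable without __name__ (a functools.partial object) is now accepted.
f = functools.partial(torch.add)
out = torch.vmap(f)(torch.randn(4, 3), torch.randn(4, 3))
print(out.shape)  # torch.Size([4, 3])
```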

Test Plan: - `pytest test/test_vmap.py -v`, added new tests.

Reviewed By: yf225

Differential Revision: D23130235

Pulled By: zou3519

fbshipit-source-id: 937f3640cc4d759bf6fa38b600161f5387a54dcf
2020-08-17 13:36:49 -07:00
Richard Zou
bda0007620 Improve calling backward() and grad() inside vmap error messages (#42876)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42876

Previously, the error messages were pretty bad. This PR adds nice
error messages for the following cases:
- user attempts to call .backward() inside vmap for any reason
whatsoever
- user attempts to call autograd.grad(outputs, inputs, grad_outputs),
where outputs or inputs is being vmapped over (so they are
BatchedTensors).

The case we do support is calling autograd.grad(outputs, inputs,
grad_outputs) where `grad_outputs` is being vmapped over. This is the
case for batched gradient support (e.g., user passes in a batched
grad_output).

Test Plan: - new tests: `pytest test/test_vmap.py -v`

Reviewed By: ezyang

Differential Revision: D23059836

Pulled By: zou3519

fbshipit-source-id: 2fd4e3fd93f558e67e2f0941b18f0d00d8ab439f
2020-08-12 10:05:31 -07:00
Richard Zou
e8f4b04d9a vmap: temporarily disable support for random functions (#42617)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42617

While we figure out the random plan, I want to initially disable
support for random operations. This is because there is an ambiguity in
what randomness means. For example,

```
tensor = torch.zeros(B0, 1)
vmap(lambda t: t.normal_())(tensor)
```

in the above example, should tensor[0] and tensor[1] be equal (i.e.,
use the same random seed), or should they be different?

The mechanism for disabling random support is as follows:
- We add a new dispatch key called VmapMode
- Whenever we're inside vmap, we enable VmapMode for all tensors.
This is done via at::VmapMode::increment_nesting and
at::VmapMode::decrement_nesting.
- DispatchKey::VmapMode's fallback kernel is the fallthrough kernel.
- We register kernels that raise errors for all random functions on
DispatchKey::VmapMode. This way, whenever someone calls a random
function on any tensor (not just BatchedTensors) inside of a vmap block,
an error gets thrown.

Test Plan: - pytest test/test_vmap.py -v -k "Operators"

Reviewed By: ezyang

Differential Revision: D22954840

Pulled By: zou3519

fbshipit-source-id: cb8d71062d4087e10cbf408f74b1a9dff81a226d
2020-08-11 07:19:51 -07:00
Richard Zou
8f67c7a624 BatchedTensor fallback: extended to support ops with multiple Tensor returns (#42628)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42628

This PR extends the BatchedTensor fallback to support operators with
multiple Tensor returns. If an operator has multiple returns, we stack
shards of each return to create the full outputs.

Test Plan:
- `pytest test/test_vmap.py -v`. Added a new test for an operator with
multiple returns (torch.var_mean).

Reviewed By: izdeby

Differential Revision: D22957095

Pulled By: zou3519

fbshipit-source-id: 5c0ec3bf51283cc4493b432bcfed1acf5509e662
2020-08-10 17:42:03 -07:00
Richard Zou
f3e8fff0d2 Batching rules for: chunk, split, unbind (#42480)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42480

These are grouped together because they all return a tuple of multiple
tensors.

This PR implements batching rules for chunk, split, and unbind. It also
updates the testing logic. Previously, reference_vmap was not able to
handle multiple outputs; now it does.
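
A minimal sketch of a multi-output op under vmap (assuming the prototype `torch.vmap` API):
```py
import torch

# Each output of chunk keeps the batch dim.
x = torch.randn(4, 6)
a, b, c = torch.vmap(lambda t: t.chunk(3))(x)
print(a.shape)  # torch.Size([4, 2])
```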

Test Plan: - `pytest test/test_vmap.py -v -k "Operators"`

Reviewed By: ezyang

Differential Revision: D22905401

Pulled By: zou3519

fbshipit-source-id: 9963c943d035e9035c866be74dbdf7ab1989f8c4
2020-08-04 08:33:43 -07:00
Richard Zou
f1d7f001b9 Batching rules for: torch.movedim, torch.narrow, Tensor.unfold (#42474)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42474

Test Plan: - `pytest test/test_vmap.py -v -k "Operators"`

Reviewed By: ezyang

Differential Revision: D22903513

Pulled By: zou3519

fbshipit-source-id: 06b3fb0c7d12b9a045c73a5c5a4f4e3207e07b02
2020-08-04 08:33:41 -07:00