Support undefined grads in vmap fallback (#46671)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46671

Previously, the vmap fallback would choke whenever it saw an undefined
tensor. The fallback works by running the operator once per sample in the
batch and then stacking the per-sample outputs to get the actual output.
Undefined tensors can show up among those per-sample outputs while
computing batched gradients with vmap.
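
To make the run-per-sample-then-stack behavior concrete, here is a rough
Python-level sketch of the fallback for a single-output op over one batch
dimension (names are illustrative; the real implementation is in C++ and
also handles multiple batch dims and multiple returns):

import torch

def fallback_sketch(op, batched_input):
    # Run the op once per sample along the batch dim, then stack the
    # per-sample outputs back into one batched output.
    per_sample_outputs = [op(sample) for sample in batched_input.unbind(0)]
    return torch.stack(per_sample_outputs)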

This PR updates the vmap fallback to handle undefined tensors, which can
appear in backward formulas (see the sketch after this list):
- if the output is undefined for every sample in the batch, the vmap
  fallback returns an undefined tensor
- if the output is defined for every sample in the batch, the vmap
  fallback stacks the defined tensors together
- if the output is defined for some samples and undefined for others,
  we error out.
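
A minimal Python-level sketch of the stacking rule above, mirroring the
safeStack helper added in the C++ diff below (the helper name and the use
of None to stand in for an undefined tensor are illustrative):

import torch

def safe_stack_sketch(per_sample_outputs):
    # Every sample produced a gradient: stack into a batched gradient.
    if all(t is not None for t in per_sample_outputs):
        return torch.stack(per_sample_outputs)
    # No sample produced a gradient: the batched gradient is undefined too.
    if all(t is None for t in per_sample_outputs):
        return None
    # A mix of defined and undefined per-sample gradients is not supported.
    raise RuntimeError(
        "vmap fallback received a mix of undefined and defined tensors")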

Test Plan: - new tests

Reviewed By: ezyang

Differential Revision: D24454909

Pulled By: zou3519

fbshipit-source-id: d225382fd17881f23c9833323b68834cfef351f3
Authored by Richard Zou on 2020-10-23 14:22:03 -07:00; committed by Facebook GitHub Bot
parent 85954164a4, commit aa828bf084
2 changed files with 51 additions and 1 deletion


@@ -195,6 +195,32 @@ void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::j
   torch::jit::push(stack, self);
 }

+static Tensor safeStack(TensorList tensors) {
+  auto is_defined = [](const Tensor& t) { return t.defined(); };
+  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
+    return at::stack(tensors);
+  }
+  // NOTE [vmap through backward and undefined grad]
+  // While vmapping through backward functions (to compute batched grad), it
+  // is possible for the backward function to return an undefined grad for some
+  // grad_input for every example. In that case, we return an undefined grad.
+  //
+  // It is theoretically possible for *some* of the examples to produce an
+  // undefined grad (a kernel could peek at the gradient values and return an
+  // undefined tensor if it determines the gradient is full of zeros). We
+  // could handle this by treating the undefined grad as a zero-filled tensor
+  // of the correct shape while stacking the tensors together. However I expect
+  // this to happen very rarely (I have not been able to find an example in our
+  // codebase) so we just error out in this case.
+  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
+    return Tensor();
+  }
+  TORCH_CHECK(false,
+      "vmap: slow fallback received a mix of undefined and defined tensors ",
+      "as the result of an operation. This is not supported, please file us ",
+      "an issue on github.");
+}
+
 // The general flow of the algorithm is as follows.
 // - First, we figure out which arguments are BatchedTensors and save them
 //   to a vector. We also store a vector of which index of the arguments list
@@ -318,7 +344,12 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
   auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
   for (int64_t return_idx = 0; return_idx < num_returns; ++return_idx) {
     auto shards = output_shards_chunks[return_idx];
-    auto flat_output = at::stack(shards);
+    auto flat_output = safeStack(shards);
+    // See NOTE [vmap through backward and undefined grad]
+    if (!flat_output.defined()) {
+      torch::jit::push(stack, flat_output);
+      continue;
+    }
     VmapDimVector output_sizes(batch_sizes);
     output_sizes.insert(
         output_sizes.end(),


@@ -733,6 +733,25 @@ class TestVmapAPI(TestCase):
         result = vmap(model)(tensor)
         self.assertEqual(result, model(tensor))

+    def test_fallback_with_undefined_grad(self):
+        B0 = 7
+        x = torch.randn(2, 3, 4, 5, requires_grad=True)
+        weight = torch.randn(3, 3, 1, 1)
+        v = torch.randn(B0, 2, 3, 4, 5)
+
+        def get_vjp(v):
+            result = torch.nn.functional.conv2d(x, weight)
+            grad_x, = torch.autograd.grad(result, x, v)
+            return grad_x
+
+        # Runs vmap(get_vjp)(v), which should not error out.
+        # The backward formula for convolution returns an undefined
+        # Tensor for grad_bias because the original bias does not exist.
+        #
+        # In the future we'll probably add a batching rule for convolution
+        # backward. When this happens, we should modify this test to use a
+        # different op (and/or create and use a dummy operator) to avoid bitrot.
+        self._assert_uses_vmap_fallback([get_vjp], [v])
+
 def slice_inputs(inputs, bdims, i):
     result = []