mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Support undefined grads in vmap fallback (#46671)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46671 Previously, the vmap fallback would choke whenever it saw an undefined tensor. For each sample in a batch, the fallback runs an operator and then stacks together outputs to get the actual output. Undefined tensors can occur as outputs while computing batched gradients with vmap. This PR updates the vmap fallback to handle undefined tensors which can appear in backward formulas: - if for each sample in a batch the output was undefined, then the vmap fallback returns an undefined tensor - if for each sample in a batch the output is defined, then the vmap fallback stacks together the defined tensors - if for some samples in a batch the output is defined/undefined, then we error out. Test Plan: - new tests Reviewed By: ezyang Differential Revision: D24454909 Pulled By: zou3519 fbshipit-source-id: d225382fd17881f23c9833323b68834cfef351f3
This commit is contained in:
parent
85954164a4
commit
aa828bf084
|
|
@ -195,6 +195,32 @@ void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::j
|
|||
torch::jit::push(stack, self);
|
||||
}
|
||||
|
||||
static Tensor safeStack(TensorList tensors) {
|
||||
auto is_defined = [](const Tensor& t) { return t.defined(); };
|
||||
if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
|
||||
return at::stack(tensors);
|
||||
}
|
||||
// NOTE [vmap through backward and undefined grad]
|
||||
// While vmapping through backward functions (to compute batched grad), it
|
||||
// is possible for the backward function to return an undefined grad for some
|
||||
// grad_input for each example. In that case, we return an undefined grad.
|
||||
//
|
||||
// It is theoretically posssible for *some* of the examples to produce an
|
||||
// undefined grad (a kernel could peek at the gradient values and return an
|
||||
// undefined tensor if it determines the gradient is full of zeros). We
|
||||
// could handle this by treating the undefined grad as a zero-filled tensor
|
||||
// of the correct shape while stacking the tensors together. However I expect
|
||||
// this to happen very rarely (I have not been able to find an example in our
|
||||
// codebase) so we just error out in this case.
|
||||
if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
|
||||
return Tensor();
|
||||
}
|
||||
TORCH_CHECK(false,
|
||||
"vmap: slow fallback received a mix of undefined and defined tensors ",
|
||||
"as the result of an operation. This is not supported, please file us ",
|
||||
"an issue on github.");
|
||||
}
|
||||
|
||||
// The general flow of the algorithm is as follows.
|
||||
// - First, we figure out which arguments are BatchedTensors and save them
|
||||
// to a vector. We also store a vector of which index of the arguments list
|
||||
|
|
@ -318,7 +344,12 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
|
|||
auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
|
||||
for (int64_t return_idx = 0; return_idx < num_returns; ++return_idx) {
|
||||
auto shards = output_shards_chunks[return_idx];
|
||||
auto flat_output = at::stack(shards);
|
||||
auto flat_output = safeStack(shards);
|
||||
// See NOTE [vmap through backward and undefined grad]
|
||||
if (!flat_output.defined()) {
|
||||
torch::jit::push(stack, flat_output);
|
||||
continue;
|
||||
}
|
||||
VmapDimVector output_sizes(batch_sizes);
|
||||
output_sizes.insert(
|
||||
output_sizes.end(),
|
||||
|
|
|
|||
|
|
@ -733,6 +733,25 @@ class TestVmapAPI(TestCase):
|
|||
result = vmap(model)(tensor)
|
||||
self.assertEqual(result, model(tensor))
|
||||
|
||||
def test_fallback_with_undefined_grad(self):
|
||||
B0 = 7
|
||||
x = torch.randn(2, 3, 4, 5, requires_grad=True)
|
||||
weight = torch.randn(3, 3, 1, 1)
|
||||
v = torch.randn(B0, 2, 3, 4, 5)
|
||||
|
||||
def get_vjp(v):
|
||||
result = torch.nn.functional.conv2d(x, weight)
|
||||
grad_x, = torch.autograd.grad(result, x, v)
|
||||
return grad_x
|
||||
|
||||
# Runs vmap(get_vjp)(v), which should not error out.
|
||||
# The backward formula for convolution returns an undefined
|
||||
# Tensor for grad_bias because the original bias does not exist.
|
||||
#
|
||||
# In the future we'll probably add a batching rule for convolution
|
||||
# backward. When this happens, we should modify this test to use a
|
||||
# different op (and/or create and use a dummy operator) to avoid bitrot.
|
||||
self._assert_uses_vmap_fallback([get_vjp], [v])
|
||||
|
||||
def slice_inputs(inputs, bdims, i):
|
||||
result = []
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user