[fix] take : backward batching rule (#95772)

Fixes https://github.com/pytorch/pytorch/issues/95738

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95772
Approved by: https://github.com/zou3519
This commit is contained in:
kshitij12345 2023-03-30 17:18:17 +00:00 committed by PyTorch MergeBot
parent 7d5d5beba2
commit ffd76d11c9
3 changed files with 41 additions and 20 deletions

View File

@ -1646,6 +1646,17 @@ class TestJac(TestCase):
expected = expected.view(2, 3, 2, 3)
assert torch.allclose(y, expected)
@jacrev_and_jacfwd
def test_take(self, device, jacapi):
x = torch.rand(5)
def func(x):
y = torch.ones(3, dtype=torch.long)
z = torch.take(x, y)
return z
self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))
@FIXME_jacrev_only
def test_diff_numel(self, device, jacapi):
x = torch.randn(2, 4, device=device)
@ -2172,26 +2183,38 @@ class TestJac(TestCase):
def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
# With chunk_size=1, we shouldn't `vmap` and hence not be limited
# by it's constraints.
x = torch.randn(3, 3, device=device)
x = torch.randn(3, device=device)
idx_1 = torch.tensor([0, ], device=device)
idx_2 = torch.tensor([0, 1], device=device)
chunk_size = 1
# Function with Dynamic Op in Backward.
# This should cause jacrev/vmap(vjp) to fail.
class IdentityWithDynamicBackwardOp(torch.autograd.Function):
@staticmethod
def forward(input):
return input
def f(x, idx):
# `take` doesn't work with vmap
# as it returns an output with dynamic shape.
return torch.take(x, idx)
@staticmethod
def setup_context(ctx, inputs, output):
pass
for fn, idx in ((f, idx_1), (f, idx_2)):
jacfn = jacrev(fn, chunk_size=chunk_size, _preallocate_and_copy=_preallocate_and_copy)
actual = jacfn(x, idx)
expected = torch.autograd.functional.jacobian(partial(fn, idx=idx), x, vectorize=False)
self.assertEqual(actual, expected)
@staticmethod
def backward(ctx, grad_output):
# dynamic op in backward pass.
grad_output.nonzero()
return grad_output
msg = r"vmap: .* is not possible because there exists a Tensor"
with self.assertRaisesRegex(RuntimeError, msg):
jacrev(fn, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x, idx)
def f(x):
return IdentityWithDynamicBackwardOp.apply(x)
# With `chunk_size=1`, we don't use vmap. So the following should work.
jacfn = jacrev(f, chunk_size=1, _preallocate_and_copy=_preallocate_and_copy)
actual = jacfn(x)
expected = torch.autograd.functional.jacobian(f, x, vectorize=False)
self.assertEqual(actual, expected)
# Should fail with `chunk_size=2`.
msg = r"vmap: We do not support batching operators that can output dynamic shape."
with self.assertRaisesRegex(RuntimeError, msg):
jacrev(f, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x)
def test_complex_error(self, device):
# Verify complex input raises error

View File

@ -788,7 +788,6 @@ class TestOperators(TestCase):
xfail("normal"), # calls random op
xfail("normal", "number_mean"), # calls random op
xfail("pca_lowrank"), # calls random op
xfail("put"), # vmap: inplace into a regular tensor
# https://github.com/pytorch/pytorch/issues/96560
decorate('linalg.pinv', 'hermitian', decorator=skipIfRocm),
xfail("quantile", device_type='cpu'), # Batching rule not implemented for `at::equal`
@ -882,7 +881,6 @@ class TestOperators(TestCase):
xfail('masked_scatter'), # dynamic
xfail('nn.functional.fractional_max_pool2d'), # random
xfail('nn.functional.fractional_max_pool3d'), # random
xfail('take'), # dynamic
xfail('pca_lowrank', ''), # randomness
xfail('svd_lowrank', ''), # randomness
xfail('to_sparse', ''), # non-dense output

View File

@ -6702,9 +6702,9 @@ Tensor take_backward(
const Tensor& indices) {
Tensor grad_self = at::zeros_like(self);
// For Composite Compliance,
// if `grad` and `indices` are CCT but `self` is not
// if `grad` and `indices` are CCT but `grad_self` is not
// then we use the out-of-place variant of `put`.
if (!isTensorSubclassLike(self) &&
if (!isTensorSubclassLike(grad_self) &&
areAnyTensorSubclassLike({grad, indices})) {
return grad_self.put(indices, grad, true);
}