mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[fix] take : backward batching rule (#95772)
Fixes https://github.com/pytorch/pytorch/issues/95738 Pull Request resolved: https://github.com/pytorch/pytorch/pull/95772 Approved by: https://github.com/zou3519
This commit is contained in:
parent
7d5d5beba2
commit
ffd76d11c9
|
|
@ -1646,6 +1646,17 @@ class TestJac(TestCase):
|
|||
expected = expected.view(2, 3, 2, 3)
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
@jacrev_and_jacfwd
|
||||
def test_take(self, device, jacapi):
|
||||
x = torch.rand(5)
|
||||
|
||||
def func(x):
|
||||
y = torch.ones(3, dtype=torch.long)
|
||||
z = torch.take(x, y)
|
||||
return z
|
||||
|
||||
self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))
|
||||
|
||||
@FIXME_jacrev_only
|
||||
def test_diff_numel(self, device, jacapi):
|
||||
x = torch.randn(2, 4, device=device)
|
||||
|
|
@ -2172,26 +2183,38 @@ class TestJac(TestCase):
|
|||
def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
|
||||
# With chunk_size=1, we shouldn't `vmap` and hence not be limited
|
||||
# by it's constraints.
|
||||
x = torch.randn(3, 3, device=device)
|
||||
|
||||
x = torch.randn(3, device=device)
|
||||
idx_1 = torch.tensor([0, ], device=device)
|
||||
idx_2 = torch.tensor([0, 1], device=device)
|
||||
chunk_size = 1
|
||||
# Function with Dynamic Op in Backward.
|
||||
# This should cause jacrev/vmap(vjp) to fail.
|
||||
class IdentityWithDynamicBackwardOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(input):
|
||||
return input
|
||||
|
||||
def f(x, idx):
|
||||
# `take` doesn't work with vmap
|
||||
# as it returns an output with dynamic shape.
|
||||
return torch.take(x, idx)
|
||||
@staticmethod
|
||||
def setup_context(ctx, inputs, output):
|
||||
pass
|
||||
|
||||
for fn, idx in ((f, idx_1), (f, idx_2)):
|
||||
jacfn = jacrev(fn, chunk_size=chunk_size, _preallocate_and_copy=_preallocate_and_copy)
|
||||
actual = jacfn(x, idx)
|
||||
expected = torch.autograd.functional.jacobian(partial(fn, idx=idx), x, vectorize=False)
|
||||
self.assertEqual(actual, expected)
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# dynamic op in backward pass.
|
||||
grad_output.nonzero()
|
||||
return grad_output
|
||||
|
||||
msg = r"vmap: .* is not possible because there exists a Tensor"
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
jacrev(fn, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x, idx)
|
||||
def f(x):
|
||||
return IdentityWithDynamicBackwardOp.apply(x)
|
||||
|
||||
# With `chunk_size=1`, we don't use vmap. So the following should work.
|
||||
jacfn = jacrev(f, chunk_size=1, _preallocate_and_copy=_preallocate_and_copy)
|
||||
actual = jacfn(x)
|
||||
expected = torch.autograd.functional.jacobian(f, x, vectorize=False)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
# Should fail with `chunk_size=2`.
|
||||
msg = r"vmap: We do not support batching operators that can output dynamic shape."
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
jacrev(f, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x)
|
||||
|
||||
def test_complex_error(self, device):
|
||||
# Verify complex input raises error
|
||||
|
|
|
|||
|
|
@ -788,7 +788,6 @@ class TestOperators(TestCase):
|
|||
xfail("normal"), # calls random op
|
||||
xfail("normal", "number_mean"), # calls random op
|
||||
xfail("pca_lowrank"), # calls random op
|
||||
xfail("put"), # vmap: inplace into a regular tensor
|
||||
# https://github.com/pytorch/pytorch/issues/96560
|
||||
decorate('linalg.pinv', 'hermitian', decorator=skipIfRocm),
|
||||
xfail("quantile", device_type='cpu'), # Batching rule not implemented for `at::equal`
|
||||
|
|
@ -882,7 +881,6 @@ class TestOperators(TestCase):
|
|||
xfail('masked_scatter'), # dynamic
|
||||
xfail('nn.functional.fractional_max_pool2d'), # random
|
||||
xfail('nn.functional.fractional_max_pool3d'), # random
|
||||
xfail('take'), # dynamic
|
||||
xfail('pca_lowrank', ''), # randomness
|
||||
xfail('svd_lowrank', ''), # randomness
|
||||
xfail('to_sparse', ''), # non-dense output
|
||||
|
|
|
|||
|
|
@ -6702,9 +6702,9 @@ Tensor take_backward(
|
|||
const Tensor& indices) {
|
||||
Tensor grad_self = at::zeros_like(self);
|
||||
// For Composite Compliance,
|
||||
// if `grad` and `indices` are CCT but `self` is not
|
||||
// if `grad` and `indices` are CCT but `grad_self` is not
|
||||
// then we use the out-of-place variant of `put`.
|
||||
if (!isTensorSubclassLike(self) &&
|
||||
if (!isTensorSubclassLike(grad_self) &&
|
||||
areAnyTensorSubclassLike({grad, indices})) {
|
||||
return grad_self.put(indices, grad, true);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user