[fix] take : backward batching rule (#95772)
Fixes https://github.com/pytorch/pytorch/issues/95738

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95772
Approved by: https://github.com/zou3519
This commit is contained in:
parent
7d5d5beba2
commit
ffd76d11c9
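For context, a minimal standalone sketch of the behavior this commit enables, mirroring the new test_take test added below. It assumes a PyTorch build that includes this fix and uses the public torch.func.jacrev API rather than the test harness's jacapi fixture.

import torch
from torch.func import jacrev

x = torch.rand(5)

def func(x):
    # take() gathers from the flattened input at the given flat indices
    idx = torch.ones(3, dtype=torch.long)
    return torch.take(x, idx)

# jacrev vmaps the vjp over the output cotangents, which exercises the
# batching path through take_backward touched by this commit.
jac = jacrev(func)(x)
expected = torch.autograd.functional.jacobian(func, x)
torch.testing.assert_close(jac, expected)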
@@ -1646,6 +1646,17 @@ class TestJac(TestCase):
         expected = expected.view(2, 3, 2, 3)
         assert torch.allclose(y, expected)
 
+    @jacrev_and_jacfwd
+    def test_take(self, device, jacapi):
+        x = torch.rand(5)
+
+        def func(x):
+            y = torch.ones(3, dtype=torch.long)
+            z = torch.take(x, y)
+            return z
+
+        self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))
+
     @FIXME_jacrev_only
     def test_diff_numel(self, device, jacapi):
         x = torch.randn(2, 4, device=device)
@@ -2172,26 +2183,38 @@ class TestJac(TestCase):
     def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy):
         # With chunk_size=1, we shouldn't `vmap` and hence not be limited
         # by it's constraints.
-        x = torch.randn(3, device=device)
-        idx_1 = torch.tensor([0, ], device=device)
-        idx_2 = torch.tensor([0, 1], device=device)
-        chunk_size = 1
-
-        def f(x, idx):
-            # `take` doesn't work with vmap
-            # as it returns an output with dynamic shape.
-            return torch.take(x, idx)
-
-        for fn, idx in ((f, idx_1), (f, idx_2)):
-            jacfn = jacrev(fn, chunk_size=chunk_size, _preallocate_and_copy=_preallocate_and_copy)
-            actual = jacfn(x, idx)
-            expected = torch.autograd.functional.jacobian(partial(fn, idx=idx), x, vectorize=False)
-            self.assertEqual(actual, expected)
-
-            msg = r"vmap: .* is not possible because there exists a Tensor"
-            with self.assertRaisesRegex(RuntimeError, msg):
-                jacrev(fn, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x, idx)
-
+        x = torch.randn(3, 3, device=device)
+
+        # Function with Dynamic Op in Backward.
+        # This should cause jacrev/vmap(vjp) to fail.
+        class IdentityWithDynamicBackwardOp(torch.autograd.Function):
+            @staticmethod
+            def forward(input):
+                return input
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                pass
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                # dynamic op in backward pass.
+                grad_output.nonzero()
+                return grad_output
+
+        def f(x):
+            return IdentityWithDynamicBackwardOp.apply(x)
+
+        # With `chunk_size=1`, we don't use vmap. So the following should work.
+        jacfn = jacrev(f, chunk_size=1, _preallocate_and_copy=_preallocate_and_copy)
+        actual = jacfn(x)
+        expected = torch.autograd.functional.jacobian(f, x, vectorize=False)
+        self.assertEqual(actual, expected)
+
+        # Should fail with `chunk_size=2`.
+        msg = r"vmap: We do not support batching operators that can output dynamic shape."
+        with self.assertRaisesRegex(RuntimeError, msg):
+            jacrev(f, chunk_size=2, _preallocate_and_copy=_preallocate_and_copy)(x)
 
     def test_complex_error(self, device):
         # Verify complex input raises error
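The replacement test above illustrates jacrev's chunking behavior: with chunk_size=1 the cotangents are fed through the vjp one at a time without vmap, so an operator whose backward produces a data-dependent shape (nonzero here) still works, while chunk_size >= 2 vmaps the vjp and hits vmap's dynamic-shape restriction. A hedged standalone sketch of the same behavior, using the public torch.func.jacrev API outside the test harness:

import torch
from torch.func import jacrev

# New-style autograd.Function (ctx-less forward plus setup_context), as in the
# test above, so it composes with torch.func transforms.
class IdentityWithDynamicBackwardOp(torch.autograd.Function):
    @staticmethod
    def forward(input):
        return input

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        grad_output.nonzero()  # output shape depends on the data
        return grad_output

def f(x):
    return IdentityWithDynamicBackwardOp.apply(x)

x = torch.randn(3, 3)

# chunk_size=1: no vmap, so the dynamic-shape backward is fine.
jac = jacrev(f, chunk_size=1)(x)
torch.testing.assert_close(jac, torch.autograd.functional.jacobian(f, x))

# chunk_size >= 2: the vjp is vmapped over chunks of cotangents and raises.
try:
    jacrev(f, chunk_size=2)(x)
except RuntimeError as err:
    print(err)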
@@ -788,7 +788,6 @@ class TestOperators(TestCase):
         xfail("normal"),  # calls random op
         xfail("normal", "number_mean"),  # calls random op
         xfail("pca_lowrank"),  # calls random op
-        xfail("put"),  # vmap: inplace into a regular tensor
         # https://github.com/pytorch/pytorch/issues/96560
         decorate('linalg.pinv', 'hermitian', decorator=skipIfRocm),
         xfail("quantile", device_type='cpu'),  # Batching rule not implemented for `at::equal`
@@ -882,7 +881,6 @@ class TestOperators(TestCase):
         xfail('masked_scatter'),  # dynamic
         xfail('nn.functional.fractional_max_pool2d'),  # random
         xfail('nn.functional.fractional_max_pool3d'),  # random
-        xfail('take'),  # dynamic
         xfail('pca_lowrank', ''),  # randomness
         xfail('svd_lowrank', ''),  # randomness
         xfail('to_sparse', ''),  # non-dense output
@@ -6702,9 +6702,9 @@ Tensor take_backward(
     const Tensor& indices) {
   Tensor grad_self = at::zeros_like(self);
   // For Composite Compliance,
-  // if `grad` and `indices` are CCT but `self` is not
+  // if `grad` and `indices` are CCT but `grad_self` is not
   // then we use the out-of-place variant of `put`.
-  if (!isTensorSubclassLike(self) &&
+  if (!isTensorSubclassLike(grad_self) &&
       areAnyTensorSubclassLike({grad, indices})) {
     return grad_self.put(indices, grad, true);
   }
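The C++ hunk above only tightens the composite-compliance guard in take_backward: the check now looks at grad_self, the freshly created zeros tensor the gradient is written into, rather than at self. For intuition, a hedged Python rendition of the gradient formula itself (take_backward_sketch is an illustrative name, not a PyTorch API):

import torch

def take_backward_sketch(grad, self_, indices):
    # Gradient of torch.take: scatter the incoming grad back to the flat
    # positions that take() read from, accumulating over repeated indices.
    grad_self = torch.zeros_like(self_)
    # The C++ code picks the out-of-place `put` when grad/indices are tensor
    # subclasses (e.g. batched tensors under vmap) but grad_self is a plain
    # tensor; for this plain-tensor sketch the in-place put_ is equivalent.
    return grad_self.put_(indices, grad, accumulate=True)

x = torch.rand(5)
idx = torch.tensor([1, 1, 3])
cotangent = torch.ones(3)
print(take_backward_sketch(cotangent, x, idx))  # tensor([0., 2., 0., 1., 0.])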