Handle implicit real->complex casting for backward of stack (#84993)

Fixes: #75852

P.S.: Yay for the PyTorch foundation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84993
Approved by: https://github.com/soulitzer
Thomas Viehmann 2022-09-19 21:20:34 +00:00 committed by PyTorch MergeBot
parent cd7408e950
commit e41d758e26
5 changed files with 62 additions and 19 deletions
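
In short, a minimal sketch of the behavior this commit addresses (shapes and values are illustrative, not taken from the issue): torch.stack implicitly promotes a real input to the complex result dtype, so the gradient slice flowing back to that input is complex and must be mapped back to a real tensor.

import torch

x = torch.randn(3, dtype=torch.double, requires_grad=True)   # real input
z = torch.randn(3, dtype=torch.cdouble, requires_grad=True)  # complex input

out = torch.stack([x, z])    # x is implicitly cast real -> complex
out.abs().sum().backward()   # the upstream gradient is complex

print(x.grad.dtype)  # torch.float64: real part of the complex grad slice
print(z.grad.dtype)  # torch.complex128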


@@ -344,11 +344,15 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
         c.mul(s), "Inference tensors cannot be saved for backward.");

     // Inference tensor in TensorList input
+    // stack does not capture anymore, so disabled
+    // TODO: find alternative Function that captures a list (maybe custom fn)
+    /*
     std::vector<torch::Tensor> inputs = {s, c};
     ASSERT_THROWS_WITH(
         torch::stack(inputs), // go through kernels: VariableType(ERROR)!,
                               // ADInplaceOrView(fallthrough), CPU
         "Inference tensors cannot be saved for backward.")
+    */
   }
 }
 }


@@ -5183,22 +5183,29 @@ for shape in [(1,), ()]:
         # Please help update this test if you update the names of any of the fields we check!
         #
         a = torch.ones(1, requires_grad=True)
-        b = torch.ones(1, requires_grad=True)
-        out = torch.stack([a, b], dim=0)
-        self.assertEqual(out.grad_fn._saved_tensors, (a, b))  # TensorList -> Tuple[Tensor]
-        self.assertIsInstance(out.grad_fn._saved_tensors[0], torch.Tensor)
-        self.assertIsInstance(out.grad_fn._raw_saved_tensors[0], torch._C._autograd.SavedTensor)
-        self.assertEqual(out.grad_fn._saved_dim, 0)  # int64_t -> int
-        self.assertIsInstance(out.grad_fn._saved_dim, int)
+        b = torch.zeros(1, requires_grad=True)
+        out1 = torch.stack([a, b], dim=0)
+        out2 = (a * 2) * b
+        # TODO: I don't think we have a backward saving a list of tensors
+        #       at the moment. It used to be stack, but for no reason...
+        #       see discussion in #84993
+        # self.assertEqual(out.grad_fn._saved_tensors, (a, b))  # TensorList -> Tuple[Tensor]
+        self.assertEqual(out2.grad_fn._saved_self, a * 2)
+        self.assertIsInstance(out2.grad_fn._saved_self, torch.Tensor)
+        self.assertIsInstance(out2.grad_fn._raw_saved_self, torch._C._autograd.SavedTensor)
+        self.assertEqual(out1.grad_fn._saved_dim, 0)  # int64_t -> int
+        self.assertIsInstance(out1.grad_fn._saved_dim, int)

-        out.grad_fn._raw_saved_tensors[0].register_hooks(lambda x: x, lambda x: x)
+        out2.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)

-        out.sum().backward()
+        out2.sum().backward()
         with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
-            out.grad_fn._saved_tensors
-        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
-            out.grad_fn._raw_saved_tensors
-        self.assertEqual(out.grad_fn._saved_dim, 0)
+            out2.grad_fn._saved_self
+        # TODO: interestingly, this only happens when indexing into a saved list,
+        #       grad_fn._raw_saved_tensors[0], not when using a single saved tensor;
+        #       see discussion in #84993
+        # with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
+        #     out2.grad_fn._raw_saved_self
+        self.assertEqual(out1.grad_fn._saved_dim, 0)

         a = torch.ones(2, 2, requires_grad=True)
         indices = torch.tensor([0, 1])
@@ -8732,10 +8739,12 @@ class TestAutogradInferenceMode(TestCase):
         with self.assertRaisesRegex(RuntimeError, err_msg):
             c * s

-        # inference tensor in TensorList input
-        inputs = [s, c]
-        with self.assertRaisesRegex(RuntimeError, err_msg):
-            torch.stack(inputs)
+        # TODO: Test this with an autograd.Function when it works
+        #       stack stopped capturing a TensorList input
+        # # inference tensor in TensorList input
+        # inputs = [s, c]
+        # with self.assertRaisesRegex(RuntimeError, err_msg):
+        #     torch.stack(inputs)

     def test_mix_inference_and_normal_tensor_inplace_op(self):
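
The disabled checks here (and in the C++ test above) still want some op that saves a TensorList for backward, now that stack no longer does. A hedged sketch of one candidate, a custom autograd.Function — the name and structure are illustrative, not from this PR, and the sketch is untested:

import torch

class StackSavingInputs(torch.autograd.Function):
    # Hypothetical stand-in: saves the whole TensorList, so applying it to
    # an inference tensor should raise
    # "Inference tensors cannot be saved for backward."
    @staticmethod
    def forward(ctx, *tensors):
        ctx.save_for_backward(*tensors)  # captures the list for backward
        return torch.stack(tensors, dim=0)

    @staticmethod
    def backward(ctx, grad):
        return tuple(grad.unbind(0))

# e.g. StackSavingInputs.apply(s, c) with inference tensor s should raise.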
@@ -9050,16 +9059,23 @@ class TestMultithreadAutograd(TestCase):
     # TODO(@anjali411): add an OpInfo based test for torch.cat
     # Issue: https://github.com/pytorch/pytorch/issues/51627
-    def test_cat_r_to_c(self):
+    # https://github.com/pytorch/pytorch/issues/75852
+    def test_cat_stack_r_to_c(self):
         inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)
         inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)

         def fn(x1, x2):
             return torch.cat((x1, x2), dim=-1)

+        def fn2(x1, x2):
+            return torch.stack((x1, x2), dim=-1)
+
         torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True)
         torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True)
+        torch.autograd.gradcheck(fn2, [inp_r, inp_c], check_forward_ad=True)
+        torch.autograd.gradcheck(fn2, [inp_c, inp_r], check_forward_ad=True)

 class TestAutogradMultipleDispatch(TestCase):
     def test_autograd_multiple_dispatch_registrations(self, device):
         t = torch.randn(3, 3, device=device, requires_grad=True)
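
check_forward_ad=True also exercises the stack_jvp formula registered below. A sketch of the forward-mode path it covers, with illustrative values (fwAD is torch.autograd.forward_ad):

import torch
import torch.autograd.forward_ad as fwAD

with fwAD.dual_level():
    x = fwAD.make_dual(torch.randn(3, dtype=torch.double),
                       torch.randn(3, dtype=torch.double))  # real primal + tangent
    z = torch.randn(3, dtype=torch.cdouble)                 # complex, no tangent
    out = torch.stack([x, z], dim=-1)
    primal, tangent = fwAD.unpack_dual(out)
    # tangent is stacked from the input tangents (zeros for z)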

View File

@@ -2746,7 +2746,7 @@
   result: auto_linear

 - name: stack(Tensor[] tensors, int dim=0) -> Tensor
-  tensors: "grad.defined() ? unbind(grad, dim) : std::vector<Tensor>(tensors.size())"
+  tensors: stack_tensors_backward(grad, dim, to_args_scalartypes(tensors))
   result: stack_jvp(tensors, dim)

 # fused RNN kernels
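
The old formula unconditionally unbind()s the incoming gradient, so a real-valued input in a complex stack would receive a complex gradient slice. The new stack_tensors_backward (defined below) instead takes the real part for those slots. In Python terms, a semantic sketch — stack_tensors_backward_py is an illustrative name, not part of the codebase:

def stack_tensors_backward_py(grad, dim, dtypes):
    # dtypes: torch.dtype of each stacked input, as captured by
    # to_args_scalartypes in the formula above
    grads = []
    for i, dtype in enumerate(dtypes):
        g = grad.select(dim, i)
        if grad.is_complex() and not dtype.is_complex:
            g = g.real  # adjoint of the implicit real -> complex cast
        grads.append(g)
    return grads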

View File

@@ -894,6 +894,25 @@ std::vector<Tensor> cat_tensors_backward(
   return grad_inputs;
 }

+std::vector<Tensor> stack_tensors_backward(
+    const Tensor& grad,
+    int64_t dim,
+    const std::vector<ScalarType>& dtypes) {
+  std::vector<Tensor> grad_inputs(dtypes.size());
+  if (!grad.defined()) {
+    return grad_inputs;
+  }
+  bool grad_is_complex = grad.is_complex();
+  for (const auto i : c10::irange(dtypes.size())) {
+    auto gr = grad.select(dim, i);
+    if (grad_is_complex && !at::isComplexType(dtypes[i])) {
+      gr = at::real(gr);
+    }
+    grad_inputs[i] = gr;
+  }
+  return grad_inputs;
+}
+
 std::vector<Tensor> block_diag_backward(
     const Tensor& grad,
     const std::vector<std::vector<int64_t>>& sizes,
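
A quick end-to-end check of this backward (hypothetical values): the gradient arriving at a real input should be the real part of its slice of the upstream complex gradient.

import torch

x = torch.randn(3, dtype=torch.double, requires_grad=True)
z = torch.randn(3, dtype=torch.cdouble, requires_grad=True)
out = torch.stack([x, z], dim=0)  # complex result
g = torch.randn_like(out)         # complex upstream gradient
out.backward(g)
assert torch.allclose(x.grad, g[0].real)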


@@ -225,6 +225,10 @@ std::vector<at::Tensor> cat_tensors_backward(
     const std::vector<std::vector<int64_t>>& sizes,
     const std::vector<ScalarType>& dtypes,
     int64_t dim);
+std::vector<at::Tensor> stack_tensors_backward(
+    const at::Tensor& grad,
+    int64_t dim,
+    const std::vector<ScalarType>& dtypes);
 std::vector<at::Tensor> block_diag_backward(
     const at::Tensor& grad,
     const std::vector<std::vector<int64_t>>& sizes,