Handle implicit real->complex casting for backward of stack (#84993)
Fixes: #75852

P.S.: Yay for the PyTorch foundation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84993
Approved by: https://github.com/soulitzer
This commit is contained in:
parent cd7408e950
commit e41d758e26
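As a quick illustration of what this change addresses (a sketch mirroring the new test_cat_stack_r_to_c test further down, not code taken from the commit): stacking a real tensor together with a complex one promotes the output to complex, and the backward of stack now returns a real gradient for the real input by taking the real part of its gradient slice, the same way cat_tensors_backward already does.

import torch

# Mixed real/complex inputs; torch.stack promotes the result to complex.
inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)

out = torch.stack((inp_r, inp_c), dim=-1)  # dtype: torch.cdouble
out.real.sum().backward()

# With this change the real input receives a real gradient
# (its complex gradient slice is projected via at::real).
assert inp_r.grad.dtype == torch.double
assert inp_c.grad.dtype == torch.cdouble

# The new test added by this PR exercises the same path with gradcheck:
def fn2(x1, x2):
    return torch.stack((x1, x2), dim=-1)

torch.autograd.gradcheck(fn2, [inp_r, inp_c], check_forward_ad=True)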
@@ -344,11 +344,15 @@ TEST(InferenceModeTest, TestMixInferenceAndNormalTensorFunctionalOp) {
        c.mul(s), "Inference tensors cannot be saved for backward.");

    // Inference tensor in TensorList input
    // stack does not capture anymore, so disabled
    // TODO: find alternative Function that captures a list (maybe custom fn)
    /*
    std::vector<torch::Tensor> inputs = {s, c};
    ASSERT_THROWS_WITH(
        torch::stack(inputs), // go through kernels: VariableType(ERROR)!,
                              // ADInplaceOrView(fallthrough), CPU
        "Inference tensors cannot be saved for backward.")
    */
    }
  }
}
@@ -5183,22 +5183,29 @@ for shape in [(1,), ()]:
# Please help update this test if you update the names of any the fields we check!
#
a = torch.ones(1, requires_grad=True)
b = torch.ones(1, requires_grad=True)
out = torch.stack([a, b], dim=0)
self.assertEqual(out.grad_fn._saved_tensors, (a, b)) # TensorList -> Tuple[Tensor]
self.assertIsInstance(out.grad_fn._saved_tensors[0], torch.Tensor)
self.assertIsInstance(out.grad_fn._raw_saved_tensors[0], torch._C._autograd.SavedTensor)
self.assertEqual(out.grad_fn._saved_dim, 0) # int64_t -> int
self.assertIsInstance(out.grad_fn._saved_dim, int)
b = torch.zeros(1, requires_grad=True)
out1 = torch.stack([a, b], dim=0)
out2 = (a * 2) * b
# TODO: I don't think we have a backward saving a list of tensors
# at the moment. It used to be stack, but for no reason...
# see discussion in #84993
# self.assertEqual(out.grad_fn._saved_tensors, (a, b)) # TensorList -> Tuple[Tensor]
self.assertEqual(out2.grad_fn._saved_self, a * 2)
self.assertIsInstance(out2.grad_fn._saved_self, torch.Tensor)
self.assertIsInstance(out2.grad_fn._raw_saved_self, torch._C._autograd.SavedTensor)
self.assertEqual(out1.grad_fn._saved_dim, 0) # int64_t -> int
self.assertIsInstance(out1.grad_fn._saved_dim, int)

out.grad_fn._raw_saved_tensors[0].register_hooks(lambda x: x, lambda x: x)
out2.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)

out.sum().backward()
out2.sum().backward()
with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
    out.grad_fn._saved_tensors
with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
    out.grad_fn._raw_saved_tensors
self.assertEqual(out.grad_fn._saved_dim, 0)
out2.grad_fn._saved_self
# TODO: interestingly, this only happens if indexing into a list grad_fn._raw_saved_tensors[0],
# not when using a saved tensor, see discussion in #84993
# with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
#     out2.grad_fn._raw_saved_self
self.assertEqual(out1.grad_fn._saved_dim, 0)

a = torch.ones(2, 2, requires_grad=True)
indices = torch.tensor([0, 1])
@@ -8732,10 +8739,12 @@ class TestAutogradInferenceMode(TestCase):
with self.assertRaisesRegex(RuntimeError, err_msg):
    c * s

# inference tensor in TensorList input
inputs = [s, c]
with self.assertRaisesRegex(RuntimeError, err_msg):
    torch.stack(inputs)
# TODO: Test this with an autograd.Function when it works
# stack stopped capturing a TensorList input
# # inference tensor in TensorList input
# inputs = [s, c]
# with self.assertRaisesRegex(RuntimeError, err_msg):
#     torch.stack(inputs)

def test_mix_inference_and_normal_tensor_inplace_op(self):
@@ -9050,16 +9059,23 @@ class TestMultithreadAutograd(TestCase):

# TODO(@anjali411): add an OpInfo based test for torch.cat
# Issue: https://github.com/pytorch/pytorch/issues/51627
def test_cat_r_to_c(self):
# https://github.com/pytorch/pytorch/issues/75852
def test_cat_stack_r_to_c(self):
    inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)
    inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)

    def fn(x1, x2):
        return torch.cat((x1, x2), dim=-1)

    def fn2(x1, x2):
        return torch.stack((x1, x2), dim=-1)

    torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True)
    torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True)

    torch.autograd.gradcheck(fn2, [inp_r, inp_c], check_forward_ad=True)
    torch.autograd.gradcheck(fn2, [inp_c, inp_r], check_forward_ad=True)

class TestAutogradMultipleDispatch(TestCase):
    def test_autograd_multiple_dispatch_registrations(self, device):
        t = torch.randn(3, 3, device=device, requires_grad=True)
@@ -2746,7 +2746,7 @@
  result: auto_linear

- name: stack(Tensor[] tensors, int dim=0) -> Tensor
  tensors: "grad.defined() ? unbind(grad, dim) : std::vector<Tensor>(tensors.size())"
  tensors: stack_tensors_backward(grad, dim, to_args_scalartypes(tensors))
  result: stack_jvp(tensors, dim)

# fused RNN kernels
@@ -894,6 +894,25 @@ std::vector<Tensor> cat_tensors_backward(
  return grad_inputs;
}

std::vector<Tensor> stack_tensors_backward(
    const Tensor& grad,
    int64_t dim,
    const std::vector<ScalarType>& dtypes) {
  std::vector<Tensor> grad_inputs(dtypes.size());
  if (!grad.defined()) {
    return grad_inputs;
  }
  bool grad_is_complex = grad.is_complex();
  for (const auto i : c10::irange(dtypes.size())) {
    auto gr = grad.select(dim, i);
    if (grad_is_complex && !at::isComplexType(dtypes[i])) {
      gr = at::real(gr);
    }
    grad_inputs[i] = gr;
  }
  return grad_inputs;
}

std::vector<Tensor> block_diag_backward(
    const Tensor& grad,
    const std::vector<std::vector<int64_t>>& sizes,
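For reference, a rough Python rendering of the stack_tensors_backward helper above; this is only an illustrative sketch (the function name suffix and the None handling for an undefined gradient are mine), while the actual backward is the C++ function wired up through derivatives.yaml. Each input's gradient is the corresponding slice of the output gradient along dim, and slices destined for real-dtype inputs are projected back to their real part.

from typing import List, Optional
import torch

def stack_tensors_backward_sketch(
    grad: Optional[torch.Tensor],
    dim: int,
    dtypes: List[torch.dtype],
) -> List[Optional[torch.Tensor]]:
    # An undefined gradient on the C++ side is modelled here as None.
    if grad is None:
        return [None] * len(dtypes)
    grad_is_complex = grad.is_complex()
    grad_inputs: List[Optional[torch.Tensor]] = []
    for i, dtype in enumerate(dtypes):
        gr = grad.select(dim, i)  # gradient slice for input i
        if grad_is_complex and not dtype.is_complex:
            # Input i was implicitly promoted real -> complex in the forward,
            # so only the real part of its gradient slice is meaningful.
            gr = gr.real
        grad_inputs.append(gr)
    return grad_inputs

For example, stack_tensors_backward_sketch(torch.ones(2, 3, dtype=torch.cdouble), 0, [torch.double, torch.cdouble]) yields a real (float64) slice for the first input and a complex slice for the second, mirroring the at::real call in the C++ helper.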
@@ -225,6 +225,10 @@ std::vector<at::Tensor> cat_tensors_backward(
    const std::vector<std::vector<int64_t>>& sizes,
    const std::vector<ScalarType>& dtypes,
    int64_t dim);
std::vector<at::Tensor> stack_tensors_backward(
    const at::Tensor& grad,
    int64_t dim,
    const std::vector<ScalarType>& dtypes);
std::vector<at::Tensor> block_diag_backward(
    const at::Tensor& grad,
    const std::vector<std::vector<int64_t>>& sizes,