Fixes bug with tolist calls to GradTrackingTensors (#165184)
Fixes #161943

## The Fix

I implemented a recursive unwrapping helper function in `tensor_list.cpp` that looks for wrapped tensors and unwraps them. The recursive implementation is needed for multi-level GradTrackingTensors. Let me know if there are any more suggestions on fixing this issue! @guilhermeleobas @KimbingNg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165184
Approved by: https://github.com/zou3519
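For context, a rough sketch of the user-facing behavior this change targets (the values and function here are illustrative, not taken from the issue): inside `torch.func.grad`, the argument is a GradTrackingTensor, and `.tolist()` on it should return a plain Python list while gradients are still computed.

```python
import torch

def f(x):
    # Inside torch.func.grad, x is a functorch GradTrackingTensor wrapping the
    # real data; with this fix, .tolist() returns a plain Python list.
    values = x.tolist()
    assert isinstance(values, list)
    return (x ** 2).sum()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
print(torch.func.grad(f)(x))  # expected: tensor([2., 4., 6.])
```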
This commit is contained in:
parent
5c583e2573
commit
f58f301313
@@ -5222,6 +5222,101 @@ class TestCompileTransforms(TestCase):
        self.assertEqual(actual, expected)


class TestGradTrackingTensorToList(TestCase):
    """Tests for the tolist() method with GradTrackingTensor (functorch tensors)."""

    def test_tolist_with_grad(self):
        """Test that tolist works inside a grad transformation."""

        def f(x):
            # inside grad, x is a GradTrackingTensor
            result = x.tolist()
            # tolist should return a plain Python list and not fail
            self.assertIsInstance(result, list)
            self.assertEqual(result, [1.0, 2.0, 3.0])
            return (x**2).sum()

        x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
        grad_f = torch.func.grad(f)
        result = grad_f(x)
        self.assertIsInstance(result, torch.Tensor)
        # gradients should still be computed correctly
        self.assertEqual(result, [2.0, 4.0, 6.0])

    def test_tolist_nested_grad(self):
        """Test tolist with nested grad transformations."""

        def f(x):
            def g(y):
                # y is a GradTrackingTensor (lvl=1)
                inner_list = y.tolist()
                self.assertIsInstance(inner_list, list)
                return (y**2).sum()

            # x is a GradTrackingTensor (lvl=0)
            outer_list = x.tolist()
            self.assertIsInstance(outer_list, list)
            grad_g = torch.func.grad(g)
            return grad_g(x).sum()

        x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
        grad_f = torch.func.grad(f)
        result = grad_f(x)
        # should compute the second derivative
        self.assertIsInstance(result, torch.Tensor)
        # grad_f should return the derivative of (2 * x).sum(), which is 2 for every element
        self.assertEqual(
            result,
            [
                2.0,
                2.0,
                2.0,
            ],
        )

    def test_tolist_multidimensional_grad(self):
        """Test tolist with multi-dimensional tensors in grad."""

        def f(x):
            result = x.tolist()
            self.assertIsInstance(result, list)
            self.assertEqual(len(result), 2)
            self.assertEqual(result, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            return x.sum()

        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True)
        grad_f = torch.func.grad(f)
        result = grad_f(x)
        self.assertIsInstance(result, torch.Tensor)
        self.assertEqual(
            result,
            [
                [
                    1.0,
                    1.0,
                    1.0,
                ],
                [1.0, 1.0, 1.0],
            ],
        )

    def test_tolist_conj_neg_grad(self):
        """Test the tolist method with conjugate/negative tensors in a grad context."""

        def f(x):
            # test with the conjugate view
            x_conj = x.conj()
            result_conj = x_conj.tolist()
            self.assertIsInstance(result_conj, list)
            return (x * x.conj()).real.sum()

        x = torch.tensor([1.0 + 2.0j, 3.0 + 4.0j], requires_grad=True)
        grad_f = torch.func.grad(f)
        result = grad_f(x)
        self.assertIsInstance(result, torch.Tensor)
        self.assertEqual(result, [2.0 + 4.0j, 6.0 + 8.0j])


only_for = ("cpu", "cuda")
instantiate_device_type_tests(
    TestGradTransform,
@@ -5301,6 +5396,9 @@ instantiate_device_type_tests(
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(
    TestGradTrackingTensorToList, globals(), only_for=only_for
)

if __name__ == "__main__":
    run_tests()
@@ -1,3 +1,4 @@
#include <ATen/functorch/TensorWrapper.h>
#include <torch/csrc/utils/tensor_list.h>

#include <c10/util/irange.h>
@@ -39,6 +40,12 @@ static PyObject* recursive_to_list(
  return list.release();
}

const Tensor& recursive_unwrap(const Tensor& tensor) {
  if (auto* wrapper = at::functorch::maybeGetTensorWrapper(tensor))
    return recursive_unwrap(wrapper->value());
  return tensor;
}

PyObject* tensor_to_list(const Tensor& tensor) {
  {
    py::object pytensor =
@@ -48,7 +55,9 @@ PyObject* tensor_to_list(const Tensor& tensor) {
        ".tolist() is not supported for tensor subclasses, got ",
        Py_TYPE(pytensor.ptr())->tp_name);
  }
  // Check if it is a grad tracking tensor and unwrap.
  Tensor data = tensor.resolve_conj().resolve_neg();
  data = recursive_unwrap(data);
  if (!data.device().is_cpu()) {
    pybind11::gil_scoped_release no_gil;
    data = data.toBackend(Backend::CPU);
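As a rough Python-side illustration of the same unwrapping idea (this sketch is not part of the PR; `get_unwrapped` and `is_gradtrackingtensor` are private functorch bindings whose availability and behavior may vary across PyTorch versions):

```python
import torch
from torch._C._functorch import get_unwrapped, is_gradtrackingtensor  # private bindings

def recursive_unwrap_py(t: torch.Tensor) -> torch.Tensor:
    # Peel off functorch grad-tracking wrappers until the plain underlying
    # tensor is reached, mirroring the C++ recursive_unwrap helper above.
    while is_gradtrackingtensor(t):
        t = get_unwrapped(t)
    return t
```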