Fixes bug with tolist calls to GradTrackingTensors (#165184)

Fixes #161943

## The Fix
I implemented a recursive unwrapping helper function in `torch/csrc/utils/tensor_list.cpp` that detects wrapped tensors and unwraps them. The recursion is needed because nested grad transforms produce multi-level GradTrackingTensors.
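
For reference, a minimal reproduction of the failure mode in the spirit of issue #161943 (a sketch; the function `f` is illustrative and the exact pre-fix error message is omitted):

```python
import torch

def f(x):
    # Inside torch.func.grad, x is wrapped as a GradTrackingTensor.
    # Before this fix, calling tolist() on it raised an error instead of
    # returning a plain Python list.
    values = x.tolist()
    return (x ** 2).sum()

torch.func.grad(f)(torch.tensor([1.0, 2.0, 3.0]))
```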

Let me know if there are any more suggestions on fixing this issue!

@guilhermeleobas @KimbingNg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165184
Approved by: https://github.com/zou3519
Samuel Park 2025-10-15 12:54:28 +00:00 committed by PyTorch MergeBot
parent 5c583e2573
commit f58f301313
2 changed files with 107 additions and 0 deletions


@@ -5222,6 +5222,101 @@ class TestCompileTransforms(TestCase):
        self.assertEqual(actual, expected)


class TestGradTrackingTensorToList(TestCase):
    """Tests for the tolist() method with GradTrackingTensor (functorch tensors)."""

    def test_tolist_with_grad(self):
        """Test that tolist works inside a grad transformation."""

        def f(x):
            # inside grad, x is a GradTrackingTensor
            result = x.tolist()
            # tolist should return a Python list and not fail
            self.assertIsInstance(result, list)
            self.assertEqual(result, [1.0, 2.0, 3.0])
            return (x**2).sum()

        x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
        grad_f = torch.func.grad(f)
        result = grad_f(x)
        self.assertIsInstance(result, torch.Tensor)
        # gradients should still be computed correctly
        self.assertEqual(result, [2.0, 4.0, 6.0])
    def test_tolist_nested_grad(self):
        """Test `tolist` with nested grad transformations."""

        def f(x):
            def g(y):
                # y is a GradTrackingTensor (lvl=1)
                inner_list = y.tolist()
                self.assertIsInstance(inner_list, list)
                return (y**2).sum()

            # x is a GradTrackingTensor (lvl=0)
            outer_list = x.tolist()
            self.assertIsInstance(outer_list, list)
            grad_g = torch.func.grad(g)
            return grad_g(x).sum()

        x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
        grad_f = torch.func.grad(f)
        result = grad_f(x)
        # should compute the second derivative
        self.assertIsInstance(result, torch.Tensor)
        # grad_f differentiates grad_g(x).sum() = (2 * x).sum(), so each entry is 2.0
        self.assertEqual(result, [2.0, 2.0, 2.0])
    def test_tolist_multidimensional_grad(self):
        """Test tolist with multi-dimensional tensors in grad."""

        def f(x):
            result = x.tolist()
            self.assertIsInstance(result, list)
            self.assertEqual(len(result), 2)
            self.assertEqual(result, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            return x.sum()

        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True)
        grad_f = torch.func.grad(f)
        result = grad_f(x)
        self.assertIsInstance(result, torch.Tensor)
        self.assertEqual(result, [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
    def test_tolist_conj_neg_grad(self):
        """Test tolist method with conjugate/negative tensors in a grad context."""

        def f(x):
            # test with the conjugate view
            x_conj = x.conj()
            result_conj = x_conj.tolist()
            self.assertIsInstance(result_conj, list)
            return (x * x.conj()).real.sum()

        x = torch.tensor([1.0 + 2.0j, 3.0 + 4.0j], requires_grad=True)
        grad_f = torch.func.grad(f)
        result = grad_f(x)
        self.assertIsInstance(result, torch.Tensor)
        self.assertEqual(result, [2.0 + 4.0j, 6.0 + 8.0j])


only_for = ("cpu", "cuda")
instantiate_device_type_tests(
    TestGradTransform,
@@ -5301,6 +5396,9 @@ instantiate_device_type_tests(
    globals(),
    only_for=only_for,
)
instantiate_device_type_tests(
    TestGradTrackingTensorToList, globals(), only_for=only_for
)

if __name__ == "__main__":
    run_tests()


@@ -1,3 +1,4 @@
#include <ATen/functorch/TensorWrapper.h>
#include <torch/csrc/utils/tensor_list.h>
#include <c10/util/irange.h>

@@ -39,6 +40,12 @@ static PyObject* recursive_to_list(
  return list.release();
}

// Recursively unwrap functorch TensorWrapper layers (GradTrackingTensors);
// recursion handles tensors wrapped multiple times by nested grad transforms.
const Tensor& recursive_unwrap(const Tensor& tensor) {
  if (auto* wrapper = at::functorch::maybeGetTensorWrapper(tensor))
    return recursive_unwrap(wrapper->value());
  return tensor;
}

PyObject* tensor_to_list(const Tensor& tensor) {
  {
    py::object pytensor =

@@ -48,7 +55,9 @@ PyObject* tensor_to_list(const Tensor& tensor) {
        ".tolist() is not supported for tensor subclasses, got ",
        Py_TYPE(pytensor.ptr())->tp_name);
  }
  // check if it is a grad tracking tensor and unwrap.
  Tensor data = tensor.resolve_conj().resolve_neg();
  data = recursive_unwrap(data);
  if (!data.device().is_cpu()) {
    pybind11::gil_scoped_release no_gil;
    data = data.toBackend(Backend::CPU);