Fixes bug with tolist() calls on GradTrackingTensors (#165184)
Fixes #161943

## The Fix

I implemented a recursive unwrapping helper function in `torch/csrc/utils/tensor_list.cpp` that looks for wrapped tensors and unwraps them. The recursive implementation is needed for multi-level GradTrackingTensors. Let me know if there are any more suggestions on fixing this issue! @guilhermeleobas @KimbingNg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165184
Approved by: https://github.com/zou3519
This commit is contained in:
parent
5c583e2573
commit
f58f301313
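A minimal sketch of the behavior this change enables, distilled from the tests added below (assuming a PyTorch build with `torch.func`; this snippet is illustrative and not part of the commit):

```python
import torch

def f(x):
    # Inside grad, x is wrapped as a GradTrackingTensor; with this fix,
    # tolist() unwraps it and returns a plain Python list of the values.
    assert x.tolist() == [1.0, 2.0, 3.0]
    return (x ** 2).sum()

grads = torch.func.grad(f)(torch.tensor([1.0, 2.0, 3.0]))
print(grads)  # tensor([2., 4., 6.])
```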
test/functorch/test_eager_transforms.py
@@ -5222,6 +5222,101 @@ class TestCompileTransforms(TestCase):
         self.assertEqual(actual, expected)
 
 
+class TestGradTrackingTensorToList(TestCase):
+    """Tests for the tolist() method with GradTrackingTensors (functorch tensors)."""
+
+    def test_tolist_with_grad(self):
+        """Test that tolist() works inside a grad transformation."""
+
+        def f(x):
+            # inside grad, x is a GradTrackingTensor
+            result = x.tolist()
+            # tolist should return a Python list and not fail
+            self.assertIsInstance(result, list)
+            self.assertEqual(result, [1.0, 2.0, 3.0])
+            return (x**2).sum()
+
+        x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
+        grad_f = torch.func.grad(f)
+        result = grad_f(x)
+        self.assertIsInstance(result, torch.Tensor)
+        # gradients should still be computed correctly
+        self.assertEqual(result, [2.0, 4.0, 6.0])
+
+    def test_tolist_nested_grad(self):
+        """Test `tolist` with nested grad transformations."""
+
+        def f(x):
+            def g(y):
+                # y is a GradTrackingTensor (lvl=1)
+                inner_list = y.tolist()
+                self.assertIsInstance(inner_list, list)
+                return (y**2).sum()
+
+            # x is a GradTrackingTensor (lvl=0)
+            outer_list = x.tolist()
+            self.assertIsInstance(outer_list, list)
+            grad_g = torch.func.grad(g)
+            return grad_g(x).sum()
+
+        x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
+        grad_f = torch.func.grad(f)
+        result = grad_f(x)
+        # should compute the second derivative
+        self.assertIsInstance(result, torch.Tensor)
+        # grad_f differentiates grad_g(x).sum() = (2 * x).sum(), so each entry is 2.0
+        self.assertEqual(
+            result,
+            [
+                2.0,
+                2.0,
+                2.0,
+            ],
+        )
+
+    def test_tolist_multidimensional_grad(self):
+        """Test tolist() with multi-dimensional tensors in grad."""
+
+        def f(x):
+            result = x.tolist()
+            self.assertIsInstance(result, list)
+            self.assertEqual(len(result), 2)
+            self.assertEqual(result, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+            return x.sum()
+
+        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], requires_grad=True)
+        grad_f = torch.func.grad(f)
+        result = grad_f(x)
+        self.assertIsInstance(result, torch.Tensor)
+        self.assertEqual(
+            result,
+            [
+                [
+                    1.0,
+                    1.0,
+                    1.0,
+                ],
+                [1.0, 1.0, 1.0],
+            ],
+        )
+
+    def test_tolist_conj_neg_grad(self):
+        """Test the tolist() method with conjugate/negative tensors in a grad context."""
+
+        def f(x):
+            # test with the conjugate view
+            x_conj = x.conj()
+            result_conj = x_conj.tolist()
+            self.assertIsInstance(result_conj, list)
+            return (x * x.conj()).real.sum()
+
+        x = torch.tensor([1.0 + 2.0j, 3.0 + 4.0j], requires_grad=True)
+        grad_f = torch.func.grad(f)
+        result = grad_f(x)
+        self.assertIsInstance(result, torch.Tensor)
+        self.assertEqual(result, [2.0 + 4.0j, 6.0 + 8.0j])
+
+
 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(
     TestGradTransform,
@@ -5301,6 +5396,9 @@ instantiate_device_type_tests(
     globals(),
     only_for=only_for,
 )
+instantiate_device_type_tests(
+    TestGradTrackingTensorToList, globals(), only_for=only_for
+)
 
 if __name__ == "__main__":
     run_tests()
torch/csrc/utils/tensor_list.cpp
@@ -1,3 +1,4 @@
+#include <ATen/functorch/TensorWrapper.h>
 #include <torch/csrc/utils/tensor_list.h>
 
 #include <c10/util/irange.h>
@@ -39,6 +40,12 @@ static PyObject* recursive_to_list(
   return list.release();
 }
 
+const Tensor& recursive_unwrap(const Tensor& tensor) {
+  if (auto* wrapper = at::functorch::maybeGetTensorWrapper(tensor))
+    return recursive_unwrap(wrapper->value());
+  return tensor;
+}
+
 PyObject* tensor_to_list(const Tensor& tensor) {
   {
     py::object pytensor =
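The helper above peels `at::functorch::TensorWrapper` layers until it reaches the underlying tensor; the recursion is what handles multi-level GradTrackingTensors from nested transforms, where a single unwrap is not enough. A rough Python analogue of the same idea, written iteratively (a sketch only, assuming the private `torch._C._functorch` helpers `is_gradtrackingtensor` and `get_unwrapped`, which may vary across versions; this is not code from the commit):

```python
import torch
from torch._C._functorch import get_unwrapped, is_gradtrackingtensor

def recursive_unwrap(t: torch.Tensor) -> torch.Tensor:
    # Nested grad transforms wrap the tensor once per level,
    # so keep peeling until no wrapper remains.
    while is_gradtrackingtensor(t):
        t = get_unwrapped(t)
    return t
```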
@@ -48,7 +55,9 @@ PyObject* tensor_to_list(const Tensor& tensor) {
         ".tolist() is not supported for tensor subclasses, got ",
         Py_TYPE(pytensor.ptr())->tp_name);
   }
+  // check if it is a grad tracking tensor and unwrap.
   Tensor data = tensor.resolve_conj().resolve_neg();
+  data = recursive_unwrap(data);
   if (!data.device().is_cpu()) {
     pybind11::gil_scoped_release no_gil;
     data = data.toBackend(Backend::CPU);
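Note the ordering in `tensor_to_list`: conjugate and negative bits are materialized via `resolve_conj().resolve_neg()` before the wrapper is removed, which is what `test_tolist_conj_neg_grad` above exercises. As a standalone sanity check of that test's expected values (a sketch assuming `torch.func` is available), the gradient of the real-valued `(x * x.conj()).real.sum()` comes out as `2 * x` under PyTorch's convention for complex gradients:

```python
import torch

x = torch.tensor([1.0 + 2.0j, 3.0 + 4.0j])
grad = torch.func.grad(lambda t: (t * t.conj()).real.sum())(x)
print(grad)  # tensor([2.+4.j, 6.+8.j]), i.e. 2 * x
```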