Let torch dynamo inline torch.func.grad (#118407)

When Dynamo sees torch.func.grad, it now tries to inline all of the frames
related to it (grad_impl and the helpers in torch._functorch.eager_transforms),
so the grad transform is traced into the FX graph instead of being captured as
an opaque higher-order op.
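
As a minimal sketch of what this enables (the function below is illustrative,
not part of this PR), the whole grad call can now be compiled without a graph
break when capture_func_transforms is enabled:

    import torch

    torch._dynamo.config.capture_func_transforms = True

    def f(x):
        return x.sin().sum()

    @torch.compile(backend="eager", fullgraph=True)
    def wrapper(x):
        return torch.func.grad(f)(x)

    x = torch.randn(3)
    wrapper(x)  # gradient of sin(x).sum() w.r.t. x, i.e. x.cos()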

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118407
Approved by: https://github.com/zou3519
Guilherme Leobas 2024-02-28 13:14:37 -03:00 committed by PyTorch MergeBot
parent 5472923998
commit 491c2b4665
36 changed files with 716 additions and 310 deletions

View File

@ -26,6 +26,7 @@ from torch._dynamo.testing import (
from torch._dynamo.utils import counters, ifdynstaticdefault
from torch._higher_order_ops.wrap import wrap
from torch.testing._internal.common_utils import (
munge_exc,
TEST_WITH_TORCHDYNAMO,
xfailIfTorchDynamo,
)
@ -2259,9 +2260,7 @@ class GraphModule(torch.nn.Module):
actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"})
self.assertExpectedInline(
pprint.pformat(actual_stack),
"""\
{'sin': ['grad_impl', 'grad_impl', 'sin'],
'sum_1': ['grad_impl', 'grad_impl', 'sum_1']}""",
"""{'sin': ['sin']}""",
)
@config.patch(capture_func_transforms=True)
@ -2377,6 +2376,59 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
class HigherOrderOpVmapGuardTests(LoggingTestCase):
@config.patch(capture_func_transforms=True)
@make_logging_test(recompiles=True)
def test_vmap_grad_guard_ok(self, records):
vmap = torch.vmap
grad = torch.func.grad
def g(x):
return vmap(grad(torch.sin))(x)
@torch.compile(backend="eager")
def fn(x):
return vmap(g)(x)
x = torch.randn(4, 5)
y = fn(x)
# sanity check
self.assertEqual(len(records), 0)
self.assertEqual(x.cos(), y)
# Calling the same function again won't have any effect on guards
fn(x)
self.assertEqual(len(records), 0)
@xfailIfTorchDynamo
@config.patch(capture_func_transforms=True)
@make_logging_test(recompiles=True)
def test_grad_guard_fail(self, records):
grad = torch.func.grad
@torch.compile(backend="eager")
def fn(x):
return grad(torch.sin)(x.sum())
x = torch.randn([])
fn(x)
self.assertEqual(len(records), 0)
# calling again should not invalidate the graph
fn(x)
self.assertEqual(len(records), 0)
# call grad should retrigger compilation
x = torch.randn(3)
grad(fn)(x)
self.assertGreater(len(records), 0)
record = self.getRecord(records, "pyfunctorch")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([])""",
munge_exc(record.getMessage()),
)
@config.patch(capture_func_transforms=True)
@make_logging_test(recompiles=True)
def test_vmap_guard_ok(self, records):
@ -2452,6 +2504,59 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
record.getMessage(),
)
@xfailIfTorchDynamo
@config.patch(capture_func_transforms=True)
@make_logging_test(recompiles=True)
def test_vmap_grad_vmap_guard_fail(self, records):
vmap = torch.vmap
grad = torch.func.grad
def g(x):
y = vmap(torch.sin, randomness="same")(x)
return y.sum(0)
@torch.compile(backend="eager")
def fn(x):
return grad(g)(x)
x = torch.randn(3, 3)
y = vmap(fn, randomness="error")(x)
self.assertEqual(x.cos(), y)
# previous FX graph should be invalidated
x = torch.randn(3, 3, 4)
y = vmap(vmap(fn, randomness="different"))(x)
self.assertGreater(len(records), 0)
record = self.getRecord(records, "pyfunctorch")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
munge_exc(record.getMessage()),
)
@xfailIfTorchDynamo
@config.patch(capture_func_transforms=True)
@make_logging_test(recompiles=True)
def test_vmap_recompile_different_states(self, records):
@torch.compile(backend="eager")
def fn(x):
return torch.vmap(lambda x: x.sin())(x)
x = torch.zeros(3, 3, 4, 5)
y = torch.vmap(fn, randomness="same")(x)
self.assertEqual(len(records), 0) # sanity check
y = torch.vmap(fn, randomness="different")(x)
self.assertGreater(len(records), 0)
record = self.getRecord(records, "pyfunctorch")
self.assertIn(
"""\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
munge_exc(record.getMessage()),
)
class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
def tearDown(self):
@ -2510,19 +2615,32 @@ class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
child = L_x_
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
contiguous = call.contiguous(); call = None
return (contiguous,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_):
sin = l_x_.sin(); l_x_ = None
sum_1 = sin.sum(); sin = None
return sum_1
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = diff_args.sin()
output = sin.sum(); sin = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
grad_input = _autograd_grad[0]; _autograd_grad = None
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (grad,)
""",
)
@ -2566,20 +2684,33 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
child = L_x_
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
contiguous = call.contiguous(); call = None
return (contiguous,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_):
sin = l_x_.sin(); l_x_ = None
add = sin + 3; sin = None
sum_1 = add.sum(); add = None
return sum_1
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = diff_args.sin()
add = sin + 3; sin = None
output = add.sum(); add = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
grad_input = _autograd_grad[0]; _autograd_grad = None
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (grad,)
""",
)
@ -2609,22 +2740,35 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
child = L_x_
y = torch.randn(3)
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
call = grad_proxy.__call__(l_x_, y); grad_proxy = l_x_ = None
contiguous = call.contiguous(); call = None
return (y, contiguous)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_, y):
sin = l_x_.sin(); l_x_ = None
add = sin + y; sin = y = None
sum_1 = add.sum(); add = None
return sum_1
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = diff_args.sin()
add = sin + y; sin = None
output = add.sum(); add = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
grad_input = _autograd_grad[0]; _autograd_grad = None
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (grad, y)
""",
)
@ -2656,20 +2800,33 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
child = L_x_
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
contiguous = call.contiguous(); call = None
return (contiguous,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_):
sin = l_x_.sin(); l_x_ = None
add = sin + 3.14; sin = None
sum_1 = add.sum(); add = None
return sum_1
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = diff_args.sin()
add = sin + 3.14; sin = None
output = add.sum(); add = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
grad_input = _autograd_grad[0]; _autograd_grad = None
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (grad,)
""",
)
@ -2698,23 +2855,36 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
child = L_x_
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, 0, True); grad_body_0 = None
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
getitem = call[0]
getitem_1 = call[1]; call = None
contiguous = getitem.contiguous(); getitem = None
return (contiguous, getitem_1)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_):
sin = l_x_.sin()
add = sin + 3.14; sin = None
sum_1 = add.sum(); add = None
cos = l_x_.cos(); l_x_ = None
return (sum_1, cos)
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = diff_args.sin()
add = sin + 3.14; sin = None
output = add.sum(); add = None
aux = diff_args.cos()
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
grad_input = _autograd_grad[0]; _autograd_grad = None
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (grad, aux_2)
""",
)
@ -2742,24 +2912,38 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
l_x_ = L_x_
l_y_ = L_y_
child = L_x_
child_1 = L_y_
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, 0, True); grad_body_0 = None
call = grad_proxy.__call__(l_x_, l_y_); grad_proxy = l_x_ = l_y_ = None
getitem = call[0]
getitem_1 = call[1]; call = None
contiguous = getitem.contiguous(); getitem = None
return (contiguous, getitem_1)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_):
sin = l_x_.sin()
add = sin + l_y_; sin = l_y_ = None
sum_1 = add.sum(); add = None
cos = l_x_.cos(); l_x_ = None
return (sum_1, cos)
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
_wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(child_1, 1); child_1 = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = diff_args.sin()
add = sin + _wrap_for_grad_1; sin = _wrap_for_grad_1 = None
output = add.sum(); add = None
aux = diff_args.cos()
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
grad_input = _autograd_grad[0]; _autograd_grad = None
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (grad, aux_2)
""",
)
@ -2798,27 +2982,45 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
l_x_ = L_x_
l_y_ = L_y_
child = L_x_
child_1 = L_y_
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, (0, 1), True); grad_body_0 = None
call = grad_proxy.__call__(l_x_, l_y_); grad_proxy = l_x_ = l_y_ = None
getitem = call[0]
getitem_1 = getitem[0]
getitem_2 = getitem[1]; getitem = None
getitem_3 = call[1]; call = None
contiguous = getitem_1.contiguous(); getitem_1 = None
contiguous_1 = getitem_2.contiguous(); getitem_2 = None
return (contiguous, contiguous_1, getitem_3)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_):
sin = l_x_.sin()
add = sin + l_y_; sin = l_y_ = None
sum_1 = add.sum(); add = None
cos = l_x_.cos(); l_x_ = None
return (sum_1, cos)
child_4 = torch._C._functorch._wrap_for_grad(child, 1); child = None
child_5 = torch._C._functorch._wrap_for_grad(child_1, 1); child_1 = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(child_4)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad_1 = torch._functorch.eager_transforms._tensor_requires_grad(child_5)
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = child_4.sin()
add = sin + child_5; sin = None
output = add.sum(); add = None
aux = child_4.cos()
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child_4, child_5], create_graph = True); child_4 = child_5 = None
child_6 = _autograd_grad[0]
child_7 = _autograd_grad[1]; _autograd_grad = None
_unwrap_for_grad = torch._C._functorch._unwrap_for_grad(child_6, 1); child_6 = None
_unwrap_for_grad_1 = torch._C._functorch._unwrap_for_grad(child_7, 1); child_7 = None
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_unwrap_for_grad, _unwrap_for_grad_1, aux_2)
""",
)
self.assertExpectedInline(
@ -2826,27 +3028,45 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
l_x_ = L_x_
l_y_ = L_y_
child = L_x_
child_1 = L_y_
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, (0, 1), True); grad_body_0 = None
call = grad_proxy.__call__(l_x_, l_y_); grad_proxy = l_x_ = l_y_ = None
getitem = call[0]
getitem_1 = getitem[0]
getitem_2 = getitem[1]; getitem = None
getitem_3 = call[1]; call = None
contiguous = getitem_1.contiguous(); getitem_1 = None
contiguous_1 = getitem_2.contiguous(); getitem_2 = None
return (contiguous, contiguous_1, getitem_3)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_, l_y_):
sin = l_x_.sin()
add = sin + l_y_; sin = l_y_ = None
sum_1 = add.sum(); add = None
cos = l_x_.cos(); l_x_ = None
return (sum_1, cos)
child_4 = torch._C._functorch._wrap_for_grad(child, 1); child = None
child_5 = torch._C._functorch._wrap_for_grad(child_1, 1); child_1 = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(child_4)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad_1 = torch._functorch.eager_transforms._tensor_requires_grad(child_5)
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = child_4.sin()
add = sin + child_5; sin = None
output = add.sum(); add = None
aux = child_4.cos()
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child_4, child_5], create_graph = True); child_4 = child_5 = None
child_6 = _autograd_grad[0]
child_7 = _autograd_grad[1]; _autograd_grad = None
_unwrap_for_grad = torch._C._functorch._unwrap_for_grad(child_6, 1); child_6 = None
_unwrap_for_grad_1 = torch._C._functorch._unwrap_for_grad(child_7, 1); child_7 = None
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (_unwrap_for_grad, _unwrap_for_grad_1, aux_2)
""",
)
@ -2872,27 +3092,52 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
child = L_x_
grad_body_1 = self.grad_body_1
grad_proxy = torch.func.grad(grad_body_1, 0, False); grad_body_1 = None
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
contiguous = call.contiguous(); call = None
return (contiguous,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_):
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
contiguous = call.contiguous(); call = None
return contiguous
child_1 = torch._C._functorch._wrap_for_grad(child, 1); child = None
class GraphModule(torch.nn.Module):
def forward(self, l_x_):
sin = l_x_.sin(); l_x_ = None
sum_1 = sin.sum(); sin = None
return sum_1
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(child_1)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting()
diff_args_1 = torch._C._functorch._wrap_for_grad(child_1, 2)
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad_1 = torch._functorch.eager_transforms._tensor_requires_grad(diff_args_1)
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = diff_args_1.sin()
output = sin.sum(); sin = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True); diff_args_1 = None
grad_input = _autograd_grad[0]; _autograd_grad = None
output_2 = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None
_ = torch._C._functorch._unwrap_for_grad(output, 2); output = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((output_2,), [child_1], create_graph = True); child_1 = None
grad_input_2 = _autograd_grad_1[0]; _autograd_grad_1 = None
grad_1 = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None
__1 = torch._C._functorch._unwrap_for_grad(output_2, 1); output_2 = None
_grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (grad_1,)
""",
)
@ -2929,14 +3174,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(3, 3, 3)
actual = wrapper_fn(x)
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
self.assertEqual(len(counters["graph_break"]), 1)
assert_dict_matches_regex(
self,
dict(counters["graph_break"]),
{
r".*HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)": 2
},
)
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
@ -2956,12 +3194,7 @@ class GraphModule(torch.nn.Module):
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
(x1, x2)
)
self.assertEqual(len(counters["graph_break"]), 1)
assert_dict_matches_regex(
self,
dict(counters["graph_break"]),
{".*torch.func.grad with body that accepts non-Tensors as input": 2},
)
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
@ -2988,20 +3221,33 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
child = L_x_
grad_body_0 = self.grad_body_0
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
call = grad_proxy.__call__(l_x_, 3.0); grad_proxy = l_x_ = None
contiguous = call.contiguous(); call = None
return (contiguous,)
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
class GraphModule(torch.nn.Module):
def forward(self, l_x_, const):
sin = l_x_.sin(); l_x_ = None
sum_1 = sin.sum(); sin = None
add = sum_1 + 3.0; sum_1 = None
return add
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
sin = diff_args.sin()
sum_1 = sin.sum(); sin = None
output = sum_1 + 3.0; sum_1 = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
grad_input = _autograd_grad[0]; _autograd_grad = None
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (grad,)
""",
)
@ -3045,11 +3291,7 @@ class GraphModule(torch.nn.Module):
y = torch.randn(3, 3)
actual = wrapper_fn(x, y)
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(
dict(counters["graph_break"]),
{"torch.func.grad: kwargs arguments are currently unsupported.": 2},
)
self.assertEqual(len(counters["graph_break"]), 0)
self.assertEqual(actual, expected)
@config.patch(capture_func_transforms=True)
@ -3064,6 +3306,7 @@ class GraphModule(torch.nn.Module):
# should not recompile on second call. See Pytorch issue #118493
y = torch.vmap(fn)(x)
@xfailIfTorchDynamo
@config.patch(capture_func_transforms=True)
@config.patch(error_on_recompile=True)
def test_vmap_recompile_different_config(self):
@ -3100,6 +3343,18 @@ class GraphModule(torch.nn.Module):
with self.assertRaises(torch._dynamo.exc.RecompileError):
torch.vmap(fn, randomness="different")(x)
@config.patch(capture_func_transforms=True)
@config.patch(error_on_recompile=True)
def test_grad_recompile(self):
@torch.compile(backend="eager")
def fn(x):
return torch.func.grad(torch.sin)(x)
x = torch.randn([])
torch.func.grad(fn)(x)
# should not recompile on second call
torch.func.grad(fn)(x)
@config.patch(capture_func_transforms=True)
def test_vmap_get_wrapped(self):
counters.clear()

View File

@ -1670,7 +1670,7 @@ class TestJac(TestCase):
self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))
@FIXME_jacrev_only
@jacrev_and_jacfwd
def test_diff_numel(self, device, jacapi):
x = torch.randn(2, 4, device=device)
@ -1687,14 +1687,14 @@ class TestJac(TestCase):
expected[2, 0, 0, 3] = 1
self.assertEqual(y, expected)
@FIXME_jacrev_only
@jacrev_and_jacfwd
def test_vmap_on_jac_simple(self, device, jacapi):
x = torch.randn(2, 3, device=device)
y = vmap(jacapi(torch.sin))(x)
expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
assert torch.allclose(y, expected)
@FIXME_jacrev_only
@jacrev_and_jacfwd
def test_nested_jac_simple(self, device, jacapi):
def foo(x):
return x.sin().sum()
@ -1755,7 +1755,7 @@ class TestJac(TestCase):
self.assertTrue(isinstance(z[0], tuple))
self.assertEqual(z, ((expected_out0_x,), (expected_out1_x,)))
@FIXME_jacrev_only
@jacrev_and_jacfwd
def test_multiple_outputs_pytree(self, device, jacapi):
def f(x, y):
return {'left': 2 * x + 3 * y, 'right': 4 * x + 5 * y}
@ -1816,7 +1816,7 @@ class TestJac(TestCase):
self.assertEqual(result.dim(), 2)
self.assertEqual(result, x.new_ones(1, 1))
@FIXME_jacrev_only
@jacrev_and_jacfwd
def test_aux_tensor(self, device, jacapi):
def f(x):
y = x.clone()
@ -1912,7 +1912,7 @@ class TestJac(TestCase):
)
self.assertEqual(result, expected)
@FIXME_jacrev_only
@jacrev_and_jacfwd
def test_multiple_inputs_outputs_pytree_multidim(self, device, jacapi):
def f(dct):
a = dct['a']
@ -4806,6 +4806,7 @@ class TestCompileTransforms(TestCase):
actual = wrapper_fn(x, y)
expected = torch.compile(wrapper_fn, backend='eager', fullgraph=True)(x, y)
fn = torch.compile(wrapper_fn, backend='eager', fullgraph=True)
self.assertEqual(actual, expected)
def wrapper_fn(x, y):

View File

@ -24,6 +24,10 @@ def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
def get_single_level_autograd_function_allowed() -> bool: ...
def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
def _vmap_increment_nesting(batch_size: int, randomness: str) -> int: ...
def _vmap_decrement_nesting() -> int: ...
def _grad_increment_nesting() -> int: ...
def _grad_decrement_nesting() -> int: ...
# Defined in aten/src/ATen/functorch/Interpreter.h
class TransformType(Enum):

View File

@ -2109,7 +2109,7 @@ class InstructionTranslator(InstructionTranslatorBase):
speculation_log=speculation_log,
)
self._throw_if_in_vmap()
self._throw_if_in_functorch()
# as soon as we create the tracing context we should keep it active, so any calls
# into dynamo apis can rely on finding it
@ -2147,20 +2147,21 @@ class InstructionTranslator(InstructionTranslatorBase):
if name in f_locals:
self._freevars_ids[name] = id(f_locals[name])
def _throw_if_in_vmap(self):
def _throw_if_in_functorch(self):
# Fallback to eager in case of a graph break inside vmap
eager = torch._dynamo.lookup_backend("eager")
compiler_fn = inspect.getattr_static(
self.output.compiler_fn, "compiler_fn", self.output.compiler_fn
)
ci = torch._C._functorch.peek_interpreter_stack()
if (
ci is not None
and ci.key() == torch._C._functorch.TransformType.Vmap
and compiler_fn is not eager
):
# if it reaches here, it means Dynamo failed to inline vmap
msg = "torch.vmap(fn) requires the function to be inlined by dynamo"
forbidden_keys = (
torch._C._functorch.TransformType.Vmap,
torch._C._functorch.TransformType.Grad,
)
if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager:
# if it reaches here, it means Dynamo failed to inline a functorch function
name = ci.key().name.lower()
msg = f"torch.func.{name}(fn) requires the function to be inlined by dynamo"
unimplemented(msg)
def get_example_value(self, source: Source):

View File

@ -49,7 +49,7 @@ from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
from .variables import (
BuiltinVariable,
FunctorchVmapHigherOrderVariable,
FunctorchHigherOrderVariable,
NestedUserFunctionVariable,
SkipFunctionVariable,
TorchInGraphFunctionVariable,
@ -155,7 +155,7 @@ manual_torch_name_rule_map = {
"torch.resize_as_": SkipFunctionVariable,
"torch.resize_as_sparse_": SkipFunctionVariable,
"torch.get_default_device": TorchInGraphFunctionVariable,
# functorch
# functorch/vmap
"torch._functorch.vmap._check_int_or_none": UserFunctionVariable,
"torch._functorch.vmap._check_out_dims_is_int_or_int_pytree": UserFunctionVariable,
"torch._functorch.vmap._check_randomness_arg": UserFunctionVariable,
@ -178,8 +178,20 @@ manual_torch_name_rule_map = {
"torch._functorch.vmap.restore_vmap": UserFunctionVariable,
"torch._functorch.apis.vmap": UserFunctionVariable,
"torch._functorch.vmap.unwrap_batched": UserFunctionVariable,
"torch._functorch.vmap.vmap_impl": FunctorchVmapHigherOrderVariable,
"torch._functorch.vmap.vmap_impl": FunctorchHigherOrderVariable,
"torch._functorch.vmap.wrap_batched": UserFunctionVariable,
# functorch/grad
"torch._functorch.eager_transforms.grad_impl": FunctorchHigherOrderVariable,
"torch._functorch.apis.grad_and_value": UserFunctionVariable,
"torch._functorch.eager_transforms._as_tuple": UserFunctionVariable,
"torch._functorch.eager_transforms._check_unique_non_empty": UserFunctionVariable,
"torch._functorch.eager_transforms._create_differentiable": UserFunctionVariable,
"torch._functorch.eager_transforms._slice_argnums": UserFunctionVariable,
"torch._functorch.eager_transforms._undo_create_differentiable": UserFunctionVariable,
"torch._functorch.eager_transforms._validate_and_wrap_argnum": UserFunctionVariable,
"torch._functorch.eager_transforms._validate_and_wrap_argnums": UserFunctionVariable,
"torch._functorch.eager_transforms._wrap_all_tensors": UserFunctionVariable,
"torch._functorch.eager_transforms._wrap_tensor_for_grad": UserFunctionVariable,
"torch._constrain_as_size": UserFunctionVariable,
"torch._constrain_as_value": UserFunctionVariable,
"torch._tensor._convert": UserFunctionVariable,
@ -187,6 +199,8 @@ manual_torch_name_rule_map = {
"torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
"torch._C._functorch._add_batch_dim": TorchInGraphFunctionVariable,
"torch._C._functorch._remove_batch_dim": TorchInGraphFunctionVariable,
"torch._C._functorch._wrap_for_grad": TorchInGraphFunctionVariable,
"torch._C._functorch._unwrap_for_grad": TorchInGraphFunctionVariable,
"torch._C._functorch.is_batchedtensor": TorchInGraphFunctionVariable,
"torch._dynamo.mark_static": UserFunctionVariable,
"torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
@ -2132,36 +2146,25 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch._functorch.deprecated.vjp",
"torch._functorch.deprecated.warn_deprecated",
"torch._functorch.eager_transforms._any_differentiable",
"torch._functorch.eager_transforms._as_tuple",
"torch._functorch.eager_transforms._autograd_grad",
"torch._functorch.eager_transforms._check_unique_non_empty",
"torch._functorch.eager_transforms._chunked_standard_basis_for_",
"torch._functorch.eager_transforms._construct_standard_basis_for",
"torch._functorch.eager_transforms._create_differentiable",
"torch._functorch.eager_transforms._tensor_requires_grad",
"torch._functorch.eager_transforms._is_differentiable",
"torch._functorch.eager_transforms._jvp_with_argnums",
"torch._functorch.eager_transforms._maybe_unwrap_functional_tensor",
"torch._functorch.eager_transforms._maybe_wrap_functional_tensor",
"torch._functorch.eager_transforms._replace_args",
"torch._functorch.eager_transforms._safe_zero_index",
"torch._functorch.eager_transforms._slice_argnums",
"torch._functorch.eager_transforms._undo_create_differentiable",
"torch._functorch.eager_transforms._unwrap_all_tensors_from_functional",
"torch._functorch.eager_transforms._validate_and_wrap_argnum",
"torch._functorch.eager_transforms._validate_and_wrap_argnums",
"torch._functorch.eager_transforms._vjp_with_argnums",
"torch._functorch.eager_transforms._wrap_all_tensors_to_functional",
"torch._functorch.eager_transforms._wrap_all_tensors",
"torch._functorch.eager_transforms._wrap_tensor_for_grad",
"torch._functorch.eager_transforms.assert_flat_tuple_of_tensors",
"torch._functorch.eager_transforms.assert_non_empty_list_of_tensors",
"torch._functorch.eager_transforms.assert_non_empty_tensor_output",
"torch._functorch.eager_transforms.assert_output_is_tensor_or_tensors",
"torch._functorch.eager_transforms.enable_inplace_requires_grad",
"torch._functorch.eager_transforms.error_if_complex",
"torch._functorch.eager_transforms.functionalize",
"torch._functorch.eager_transforms.grad_and_value",
"torch._functorch.eager_transforms.grad_impl",
"torch._functorch.eager_transforms.hessian",
"torch._functorch.eager_transforms.jacfwd",
"torch._functorch.eager_transforms.jacrev",
@ -3171,6 +3174,7 @@ MOD_INLINELIST = {
"torch._dynamo.comptime",
"torch._dynamo.polyfill",
"torch._functorch.vmap",
"torch._functorch.eager_transforms",
"torch._inductor.test_operators",
"torch.amp.autocast_mode",
"torch.ao.nn",
@ -3355,7 +3359,7 @@ def check_verbose(obj, is_inlined_call=False):
rule = torch._dynamo.trace_rules.lookup_inner(
fi.py_obj, fi.name, fi.filename, is_inlined_call
)
if rule in [UserFunctionVariable, FunctorchVmapHigherOrderVariable]:
if rule in [UserFunctionVariable, FunctorchHigherOrderVariable]:
return SkipResult(
False,
"inlined according trace_rules.lookup",

View File

@ -813,7 +813,9 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor):
@contextmanager
def preserve_rng_state():
with torch.utils._python_dispatch._disable_current_modes():
disable_functorch = torch._C._DisableFuncTorch
disable_current_modes = torch.utils._python_dispatch._disable_current_modes
with disable_current_modes(), disable_functorch():
rng_state = torch.clone(torch.random.get_rng_state())
skip_frame_if_in_functorch_mode(rng_state)
if torch.cuda.is_available():

View File

@ -7,6 +7,8 @@ from .ctx_manager import (
ContextWrappingVariable,
DeterministicAlgorithmsVariable,
DisabledSavedTensorsHooksVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
GradModeVariable,
InferenceModeVariable,
StreamContextVariable,
@ -29,7 +31,7 @@ from .functions import (
UserMethodVariable,
)
from .higher_order_ops import (
FunctorchVmapHigherOrderVariable,
FunctorchHigherOrderVariable,
TorchHigherOrderOperatorVariable,
)
from .iter import (

View File

@ -1510,6 +1510,8 @@ def wrap_fx_proxy_cls(
torch._C._functorch._vmap_increment_nesting,
torch._C._functorch._vmap_decrement_nesting,
torch._functorch.vmap._validate_and_get_batch_size,
torch._C._functorch._grad_increment_nesting,
torch._C._functorch._grad_decrement_nesting,
# some mac builds are missing torch.distributed.get_rank()
getattr(torch.distributed, "get_rank", _missing),
getattr(torch.distributed, "get_world_size", _missing),

View File

@ -149,6 +149,85 @@ class GenericContextWrappingVariable(ContextWrappingVariable):
return x
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
"""represents torch grad requries grad"""
@staticmethod
def create(tx, target_values, **kwargs):
return GradInplaceRequiresGradCtxManagerVariable(
target_values=target_values,
initial_values=None,
**kwargs,
)
def enter(self, tx):
[enabled] = self.target_values
self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed()
torch._C._functorch.set_inplace_requires_grad_allowed(enabled)
self.set_cleanup_hook(
tx,
lambda: torch._C._functorch.set_inplace_requires_grad_allowed(
self.prev_state
),
)
self.state.proxy = tx.output.create_node(
"call_function",
torch._C._functorch.set_inplace_requires_grad_allowed,
(enabled,),
{},
)
return variables.ConstantVariable.create(None)
def exit(self, tx, *args):
self.state.cleanup()
tx.output.create_node(
"call_function",
torch._C._functorch.set_inplace_requires_grad_allowed,
(self.prev_state,),
{},
)
return variables.ConstantVariable.create(None)
class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
"""represents torch.func.grad increment/decrement nesting"""
# A guard is needed as the grad level is baked into the torch FX graph
# This is fine if grad is only called from within the function
# being compiled. But the FX graph may be invalid in the case of a grad
# call from eager that calls the compiled function, as the grad levels
# may be different.
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)
@staticmethod
def create(tx, **kwargs):
var = GradIncrementNestingCtxManagerVariable(
target_values=None,
initial_values=None,
**kwargs,
)
return var
def enter(self, tx):
install_guard(self._guards_singleton)
grad_level = torch._C._functorch._grad_increment_nesting()
self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
self.state.proxy = tx.output.create_node(
"call_function",
torch._C._functorch._grad_increment_nesting,
(),
{},
)
return variables.ConstantVariable.create(grad_level)
def exit(self, tx, *args):
self.state.cleanup()
tx.output.create_node(
"call_function", torch._C._functorch._grad_decrement_nesting, (), {}
)
return variables.ConstantVariable.create(None)
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
"""represents torch VMap increment/decrement nesting"""

View File

@ -1180,13 +1180,16 @@ class FunctorchGradHigherOrderVariable(TorchHigherOrderOperatorVariable):
return TupleVariable([TupleVariable(items), aux])
class FunctorchVmapHigherOrderVariable(UserFunctionVariable):
class FunctorchHigherOrderVariable(UserFunctionVariable):
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
if not torch._dynamo.config.capture_func_transforms:
name = self.get_name()
assert name in ("grad_impl", "vmap_impl")
fn = name.split("_")[0]
unimplemented(
"torch.func.vmap capture is disabled, "
f"torch.func.{fn} capture is disabled, "
"it can be turned on by setting "
"`torch._dynamo.config.capture_func_transforms=True`"
)

View File

@ -43,7 +43,6 @@ from .ctx_manager import (
TorchFunctionDisableVariable,
)
from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .lists import ListVariable, TupleVariable
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
@ -55,6 +54,8 @@ supported_ctx_manager_classes = {
torch.autograd.profiler.record_function,
torch._C.DisableTorchFunctionSubclass,
torch._functorch.vmap.vmap_increment_nesting,
torch._functorch.eager_transforms.grad_increment_nesting,
torch._functorch.eager_transforms.enable_inplace_requires_grad,
torch.amp.autocast_mode.autocast,
torch.autograd.grad_mode.enable_grad,
torch.autograd.grad_mode.inference_mode,
@ -176,6 +177,8 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
) -> "VariableTracker":
from . import (
DisabledSavedTensorsHooksVariable,
GradIncrementNestingCtxManagerVariable,
GradInplaceRequiresGradCtxManagerVariable,
GradModeVariable,
InferenceModeVariable,
StreamVariable,
@ -241,6 +244,17 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
tx,
[guard_if_dyn(x) for x in args],
)
elif self.value is torch._functorch.eager_transforms.grad_increment_nesting:
assert len(args) == 0
return GradIncrementNestingCtxManagerVariable.create(tx)
elif (
self.value is torch._functorch.eager_transforms.enable_inplace_requires_grad
):
assert len(args) == 1
return GradInplaceRequiresGradCtxManagerVariable.create(
tx,
[guard_if_dyn(x) for x in args],
)
elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
assert len(args) == 1
return DisabledSavedTensorsHooksVariable.create(
@ -293,11 +307,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
):
tx.mark_inconsistent_side_effects()
return ConstantVariable.create(tracing_state_functions[self.value])
elif self.value in (torch._functorch.eager_transforms.grad_impl,):
return TorchHigherOrderOperatorVariable.make(
self.value,
source=self.source,
).call_function(tx, args, kwargs)
elif self.value is torch.overrides.get_default_nowrap_functions.__wrapped__:
# [Note: __torch_function__] we return empty here because we restrict
# the set of functions that we trace __torch_function__ on to

View File

@ -362,3 +362,40 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
def wrapper(*args, **kwargs):
return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
return wrapper
@exposed_in("torch.func")
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""
Returns a function to compute a tuple of the gradient and primal, or
forward, computation.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return a single-element Tensor. If specified ``has_aux``
equals ``True``, function can return a tuple of single-element
Tensor and other auxiliary objects: ``(output, aux)``.
argnums (int or Tuple[int]): Specifies arguments to compute gradients
with respect to. ``argnums`` can be single integer or tuple of
integers. Default: 0.
has_aux (bool): Flag indicating that ``func`` returns a tensor and
other auxiliary objects: ``(output, aux)``. Default: False.
Returns:
Function to compute a tuple of gradients with respect to its inputs
and the forward computation. By default, the output of the function is
a tuple of the gradient tensor(s) with respect to the first argument
and the primal computation. If specified ``has_aux`` equals
``True``, tuple of gradients and tuple of the forward computation with
output auxiliary objects is returned. If ``argnums`` is a tuple of
integers, a tuple of a tuple of the output gradients with respect to
each ``argnums`` value and the forward computation is returned.
See :func:`grad` for examples
"""
from torch._functorch import eager_transforms
@functools.wraps(func)
def wrapper(*args, **kwargs):
return eager_transforms.grad_and_value_impl(func, argnums, has_aux, args, kwargs)
return wrapper
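As a quick usage sketch for the grad_and_value API documented above (the function f and the input below are illustrative):

    import torch
    from torch.func import grad_and_value

    def f(x):
        # grad_and_value requires a scalar (single-element) output
        return (x ** 2).sum()

    x = torch.randn(3)
    g, value = grad_and_value(f)(x)
    # g equals 2 * x, value equals (x ** 2).sum()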

View File

@ -67,7 +67,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
warn_deprecated('grad_and_value')
return _impl.grad_and_value(func, argnums, has_aux)
return apis.grad_and_value(func, argnums, has_aux)
def vjp(func: Callable, *primals, has_aux: bool = False):
warn_deprecated('vjp')
@ -110,7 +110,7 @@ def combine_state_for_ensemble(models):
setup_docs(vmap, apis.vmap, 'torch.vmap')
setup_docs(grad, apis.grad)
setup_docs(grad_and_value)
setup_docs(grad_and_value, apis.grad_and_value)
setup_docs(vjp)
setup_docs(jvp)
setup_docs(jacrev)

View File

@ -41,7 +41,7 @@ from torch._C._functorch import (
_assert_wrapped_functional,
_propagate_functional_input_mutation,
set_inplace_requires_grad_allowed,
get_inplace_requires_grad_allowed
get_inplace_requires_grad_allowed,
)
from torch._functorch.utils import exposed_in, argnums_t
@ -51,7 +51,7 @@ def lazy_dynamo_disable(func):
return torch._dynamo.disable(func)
@contextlib.contextmanager
def enable_inplace_requires_grad(enabled=True):
def enable_inplace_requires_grad(enabled):
prev_state = get_inplace_requires_grad_allowed()
set_inplace_requires_grad_allowed(enabled)
try:
@ -60,11 +60,16 @@ def enable_inplace_requires_grad(enabled=True):
set_inplace_requires_grad_allowed(prev_state)
def _tensor_requires_grad(x):
# avoid graph-break on x.requires_grad_()
# https://github.com/pytorch/pytorch/pull/110053
return x.requires_grad_()
def _create_differentiable(inps, level=None):
def create_differentiable(x):
if isinstance(x, torch.Tensor):
with enable_inplace_requires_grad():
return x.requires_grad_()
with enable_inplace_requires_grad(True):
return _tensor_requires_grad(x)
raise ValueError(f'Thing passed to transform API must be Tensor, '
f'got {type(x)}')
return tree_map(create_differentiable, inps)
@ -277,6 +282,15 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
return _vjp_with_argnums(func, *primals, has_aux=has_aux)
@contextlib.contextmanager
def grad_increment_nesting():
try:
grad_level = _grad_increment_nesting()
yield grad_level
finally:
_grad_decrement_nesting()
@doesnt_support_saved_tensors_hooks
def _vjp_with_argnums(func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False):
# This is the same function as vjp but also accepts an argnums argument
@ -1213,88 +1227,54 @@ def hessian(func, argnums=0):
return jacfwd(jacrev(func, argnums), argnums)
@exposed_in("torch.func")
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""
Returns a function to compute a tuple of the gradient and primal, or
forward, computation.
Args:
func (Callable): A Python function that takes one or more arguments.
Must return a single-element Tensor. If specified ``has_aux``
equals ``True``, function can return a tuple of single-element
Tensor and other auxiliary objects: ``(output, aux)``.
argnums (int or Tuple[int]): Specifies arguments to compute gradients
with respect to. ``argnums`` can be single integer or tuple of
integers. Default: 0.
has_aux (bool): Flag indicating that ``func`` returns a tensor and
other auxiliary objects: ``(output, aux)``. Default: False.
Returns:
Function to compute a tuple of gradients with respect to its inputs
and the forward computation. By default, the output of the function is
a tuple of the gradient tensor(s) with respect to the first argument
and the primal computation. If specified ``has_aux`` equals
``True``, tuple of gradients and tuple of the forward computation with
output auxiliary objects is returned. If ``argnums`` is a tuple of
integers, a tuple of a tuple of the output gradients with respect to
each ``argnums`` value and the forward computation is returned.
See :func:`grad` for examples
"""
@doesnt_support_saved_tensors_hooks
@wraps(func)
def wrapper(*args, **kwargs):
level = _grad_increment_nesting()
try:
output, aux, grad_input = None, None, None
# See NOTE [grad and vjp interaction with no_grad]
with torch.enable_grad():
args = _wrap_all_tensors(args, level)
kwargs = _wrap_all_tensors(kwargs, level)
diff_args = _slice_argnums(args, argnums, as_tuple=False)
tree_map_(partial(_create_differentiable, level=level), diff_args)
output = func(*args, **kwargs)
if has_aux:
if not (isinstance(output, tuple) and len(output) == 2):
raise RuntimeError(
"grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
"if has_aux is True"
)
output, aux = output
if not isinstance(output, torch.Tensor):
raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
f'to return a Tensor, got {type(output)}')
if output.dim() != 0:
raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
'to return a scalar Tensor, got tensor with '
f'{output.dim()} dims. Maybe you wanted to '
'use the vjp or jacrev APIs instead?')
flat_diff_args, spec = tree_flatten(diff_args)
# NB: need create_graph so that backward pass isn't run in no_grad mode
flat_outputs = _as_tuple(output)
flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
grad_input = tree_unflatten(flat_grad_input, spec)
grad_input = _undo_create_differentiable(grad_input, level)
output = _undo_create_differentiable(output, level)
if aux is not None:
aux = _undo_create_differentiable(aux, level)
@doesnt_support_saved_tensors_hooks
def grad_and_value_impl(func, argnums, has_aux, args, kwargs) -> Callable:
with grad_increment_nesting() as level:
output, aux, grad_input = None, None, None
# See NOTE [grad and vjp interaction with no_grad]
with torch.enable_grad():
args = _wrap_all_tensors(args, level)
kwargs = _wrap_all_tensors(kwargs, level)
diff_args = _slice_argnums(args, argnums, as_tuple=False)
tree_map_(partial(_create_differentiable, level=level), diff_args)
output = func(*args, **kwargs)
if has_aux:
return grad_input, (output, aux)
return grad_input, output
finally:
_grad_decrement_nesting()
return wrapper
if not (isinstance(output, tuple) and len(output) == 2):
raise RuntimeError(
"grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
"if has_aux is True"
)
output, aux = output
if not isinstance(output, torch.Tensor):
raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
f'to return a Tensor, got {type(output)}')
if output.dim() != 0:
raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
'to return a scalar Tensor, got tensor with '
f'{output.dim()} dims. Maybe you wanted to '
'use the vjp or jacrev APIs instead?')
flat_diff_args, spec = tree_flatten(diff_args)
# NB: need create_graph so that backward pass isn't run in no_grad mode
flat_outputs = _as_tuple(output)
flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
grad_input = tree_unflatten(flat_grad_input, spec)
grad_input = _undo_create_differentiable(grad_input, level)
output = _undo_create_differentiable(output, level)
if has_aux:
aux = _undo_create_differentiable(aux, level)
if has_aux:
return grad_input, (output, aux)
return grad_input, output
def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
func = lazy_dynamo_disable(func)
results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
if has_aux:
grad, (_, aux) = results
return grad, aux

View File

@ -14,6 +14,7 @@ from weakref import ReferenceType
import torch
import torch._custom_op
import torch._logging
from torch._C._functorch import is_functorch_wrapped_tensor
from torch._guards import Source
from torch._ops import OpOverload
@ -124,7 +125,7 @@ def is_fake(x):
reapply_views = torch._C._functionalization_reapply_views_tls()
unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views)
return is_fake(unwrapped)
elif isinstance(x, torch.Tensor) and torch._C._functorch.is_batchedtensor(x):
elif isinstance(x, torch.Tensor) and is_functorch_wrapped_tensor(x):
unwrapped = torch._C._functorch.get_unwrapped(x)
return is_fake(unwrapped)
return False
@ -145,7 +146,7 @@ def maybe_get_fake_mode(t):
reapply_views = torch._C._functionalization_reapply_views_tls()
unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views)
return maybe_get_fake_mode(unwrapped)
elif isinstance(t, torch.Tensor) and torch._C._functorch.is_batchedtensor(t):
elif isinstance(t, torch.Tensor) and is_functorch_wrapped_tensor(t):
unwrapped = torch._C._functorch.get_unwrapped(t)
return maybe_get_fake_mode(unwrapped)
return None

View File

@ -11,6 +11,8 @@ from torch._C._functorch import (
current_level,
get_unwrapped,
is_batchedtensor,
is_functorch_wrapped_tensor,
is_gradtrackingtensor,
maybe_get_bdim,
maybe_get_level,
peek_interpreter_stack,
@ -133,7 +135,7 @@ class MetaConverter:
# hold a weak ref to self, otherwise it will be kept alive
# by the del_ten closure
self_weak_ref = weakref.ref(self)
if t.is_sparse or t.is_mkldnn or is_batchedtensor(t):
if t.is_sparse or t.is_mkldnn or is_functorch_wrapped_tensor(t):
weak_st = None
else:
weak_st = StorageWeakRef(t._typed_storage())
@ -327,16 +329,36 @@ class MetaConverter:
if t.requires_grad and not is_leaf:
with torch.enable_grad():
r = r.clone()
elif is_batchedtensor(t):
# Wraps a BatchedTensor in a FakeTensor
elif is_functorch_wrapped_tensor(t):
if t._is_view():
from torch._dynamo.exc import unimplemented
unimplemented(
"view functorch tensors are not supported by meta conversion"
)
# Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
# in a FakeTensor
def _to_fake_tensor(t):
if is_batchedtensor(t):
ft = _to_fake_tensor(get_unwrapped(t))
lvl = maybe_get_level(t)
bdim = maybe_get_bdim(t)
r = _add_batch_dim(ft, bdim, lvl)
elif is_gradtrackingtensor(t):
disable_functorch = torch._C._DisableFuncTorch
with disable_functorch():
ft = _to_fake_tensor(get_unwrapped(t))
lvl = torch._C._functorch.maybe_get_level(t)
r = torch._C._functorch._wrap_for_grad(ft, lvl)
is_leaf = safe_is_leaf(t)
if t.requires_grad and safe_is_leaf(r):
r.requires_grad = True
elif t.requires_grad and not is_leaf:
with torch.enable_grad():
r = r.clone()
else:
# regular tensor
sizes = t.size()
strides = t.stride()
r = callback(
@ -537,6 +559,7 @@ class MetaConverter:
device="meta",
)
)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
r.requires_grad = t.requires_grad
@ -549,8 +572,8 @@ class MetaConverter:
r = r.clone(memory_format=torch.preserve_format)
# Graph-Break for wrapped tensors
if not is_batchedtensor(
t
if not (
is_batchedtensor(t) or is_gradtrackingtensor(t)
) and torch._C._functorch.is_functorch_wrapped_tensor(t):
return NotImplemented
@ -701,13 +724,16 @@ class MetaConverter:
return NotImplemented
else:
self.hit += 1
r = self.meta_tensor(
t,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
)
disable_functorch = torch._C._DisableFuncTorch
with disable_functorch():
r = self.meta_tensor(
t,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
)
if type(t) is torch.nn.Parameter:
# NB: Cannot directly use Parameter constructor
# because that would force a detach, not desirable

View File

@ -1,5 +1,4 @@
from torch._functorch.eager_transforms import (
grad_and_value,
vjp,
jvp,
jacrev,
@ -8,7 +7,7 @@ from torch._functorch.eager_transforms import (
functionalize,
linearize
)
from torch._functorch.apis import grad
from torch._functorch.apis import grad, grad_and_value
from torch._functorch.functional_call import functional_call, stack_module_state
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.apis import vmap

View File

@ -5018,6 +5018,7 @@ def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=
s = re.sub(r' File "([^"]+)", line \d+, in (.+)\n .+\n( +[~^]+ *\n)?', repl_frame, s)
s = re.sub(r"line \d+", "line N", s)
s = re.sub(r".py:\d+", ".py:N", s)
s = re.sub(file, os.path.basename(file), s)
s = re.sub(os.path.join(os.path.dirname(torch.__file__), ""), "", s)
s = re.sub(r"\\", "/", s) # for Windows