Let torch dynamo inline torch.func.grad (#118407)

When dynamo sees torch.func.grad, it tries to inline all of the frames related to it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118407
Approved by: https://github.com/zou3519
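A minimal sketch of what this enables, modeled on the tests added in this commit (the config assignment below mirrors the `config.patch(capture_func_transforms=True)` decorator used in those tests; it is illustration, not part of the diff):

    import torch

    # Opt in to capturing functorch transforms, as the new tests do.
    torch._dynamo.config.capture_func_transforms = True

    @torch.compile(backend="eager")
    def fn(x):
        # Dynamo now inlines the frames behind torch.func.grad instead of
        # graph-breaking and falling back to eager on them.
        return torch.func.grad(torch.sin)(x)

    x = torch.randn([])   # grad needs a scalar output, so use a 0-dim input
    print(fn(x))          # matches torch.cos(x)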
This commit is contained in:
parent 5472923998
commit 491c2b4665
@@ -26,6 +26,7 @@ from torch._dynamo.testing import (
from torch._dynamo.utils import counters, ifdynstaticdefault
from torch._higher_order_ops.wrap import wrap
from torch.testing._internal.common_utils import (
    munge_exc,
    TEST_WITH_TORCHDYNAMO,
    xfailIfTorchDynamo,
)
@@ -2259,9 +2260,7 @@ class GraphModule(torch.nn.Module):
        actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"})
        self.assertExpectedInline(
            pprint.pformat(actual_stack),
            """\
{'sin': ['grad_impl', 'grad_impl', 'sin'],
'sum_1': ['grad_impl', 'grad_impl', 'sum_1']}""",
            """{'sin': ['sin']}""",
        )

    @config.patch(capture_func_transforms=True)
@@ -2377,6 +2376,59 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre


class HigherOrderOpVmapGuardTests(LoggingTestCase):
    @config.patch(capture_func_transforms=True)
    @make_logging_test(recompiles=True)
    def test_vmap_grad_guard_ok(self, records):
        vmap = torch.vmap
        grad = torch.func.grad

        def g(x):
            return vmap(grad(torch.sin))(x)

        @torch.compile(backend="eager")
        def fn(x):
            return vmap(g)(x)

        x = torch.randn(4, 5)
        y = fn(x)
        # sanity check
        self.assertEqual(len(records), 0)
        self.assertEqual(x.cos(), y)

        # Calling the same function again won't have any effect on guards
        fn(x)
        self.assertEqual(len(records), 0)

    @xfailIfTorchDynamo
    @config.patch(capture_func_transforms=True)
    @make_logging_test(recompiles=True)
    def test_grad_guard_fail(self, records):
        grad = torch.func.grad

        @torch.compile(backend="eager")
        def fn(x):
            return grad(torch.sin)(x.sum())

        x = torch.randn([])
        fn(x)
        self.assertEqual(len(records), 0)

        # calling again should not invalidate the graph
        fn(x)
        self.assertEqual(len(records), 0)

        # call grad should retrigger compilation
        x = torch.randn(3)
        grad(fn)(x)
        self.assertGreater(len(records), 0)
        record = self.getRecord(records, "pyfunctorch")
        self.assertIn(
            """\
triggered by the following guard failure(s):
- torch._functorch.pyfunctorch.compare_functorch_state([])""",
            munge_exc(record.getMessage()),
        )

    @config.patch(capture_func_transforms=True)
    @make_logging_test(recompiles=True)
    def test_vmap_guard_ok(self, records):
@ -2452,6 +2504,59 @@ class HigherOrderOpVmapGuardTests(LoggingTestCase):
|
|||
record.getMessage(),
|
||||
)
|
||||
|
||||
@xfailIfTorchDynamo
|
||||
@config.patch(capture_func_transforms=True)
|
||||
@make_logging_test(recompiles=True)
|
||||
def test_vmap_grad_vmap_guard_fail(self, records):
|
||||
vmap = torch.vmap
|
||||
grad = torch.func.grad
|
||||
|
||||
def g(x):
|
||||
y = vmap(torch.sin, randomness="same")(x)
|
||||
return y.sum(0)
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
return grad(g)(x)
|
||||
|
||||
x = torch.randn(3, 3)
|
||||
y = vmap(fn, randomness="error")(x)
|
||||
self.assertEqual(x.cos(), y)
|
||||
|
||||
# previous FX graph should be invalidated
|
||||
x = torch.randn(3, 3, 4)
|
||||
y = vmap(vmap(fn, randomness="different"))(x)
|
||||
self.assertGreater(len(records), 0)
|
||||
record = self.getRecord(records, "pyfunctorch")
|
||||
self.assertIn(
|
||||
"""\
|
||||
triggered by the following guard failure(s):
|
||||
- torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'error')])""",
|
||||
munge_exc(record.getMessage()),
|
||||
)
|
||||
|
||||
@xfailIfTorchDynamo
|
||||
@config.patch(capture_func_transforms=True)
|
||||
@make_logging_test(recompiles=True)
|
||||
def test_vmap_recompile_different_states(self, records):
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
return torch.vmap(lambda x: x.sin())(x)
|
||||
|
||||
x = torch.zeros(3, 3, 4, 5)
|
||||
y = torch.vmap(fn, randomness="same")(x)
|
||||
self.assertEqual(len(records), 0) # sanity check
|
||||
|
||||
y = torch.vmap(fn, randomness="different")(x)
|
||||
self.assertGreater(len(records), 0)
|
||||
record = self.getRecord(records, "pyfunctorch")
|
||||
self.assertIn(
|
||||
"""\
|
||||
triggered by the following guard failure(s):
|
||||
- torch._functorch.pyfunctorch.compare_functorch_state([('Vmap', 1, 'same')])""",
|
||||
munge_exc(record.getMessage()),
|
||||
)
|
||||
|
||||
|
||||
class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
||||
def tearDown(self):
|
||||
|
|
@ -2510,19 +2615,32 @@ class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
child = L_x_
|
||||
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
|
||||
contiguous = call.contiguous(); call = None
|
||||
return (contiguous,)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_):
|
||||
sin = l_x_.sin(); l_x_ = None
|
||||
sum_1 = sin.sum(); sin = None
|
||||
return sum_1
|
||||
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = diff_args.sin()
|
||||
output = sin.sum(); sin = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
||||
grad_input = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (grad,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
@ -2566,20 +2684,33 @@ class GraphModule(torch.nn.Module):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
child = L_x_
|
||||
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
|
||||
contiguous = call.contiguous(); call = None
|
||||
return (contiguous,)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_):
|
||||
sin = l_x_.sin(); l_x_ = None
|
||||
add = sin + 3; sin = None
|
||||
sum_1 = add.sum(); add = None
|
||||
return sum_1
|
||||
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = diff_args.sin()
|
||||
add = sin + 3; sin = None
|
||||
output = add.sum(); add = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
||||
grad_input = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (grad,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
@ -2609,22 +2740,35 @@ class GraphModule(torch.nn.Module):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
child = L_x_
|
||||
|
||||
y = torch.randn(3)
|
||||
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_, y); grad_proxy = l_x_ = None
|
||||
contiguous = call.contiguous(); call = None
|
||||
return (y, contiguous)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_, y):
|
||||
sin = l_x_.sin(); l_x_ = None
|
||||
add = sin + y; sin = y = None
|
||||
sum_1 = add.sum(); add = None
|
||||
return sum_1
|
||||
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = diff_args.sin()
|
||||
add = sin + y; sin = None
|
||||
output = add.sum(); add = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
||||
grad_input = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (grad, y)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
@ -2656,20 +2800,33 @@ class GraphModule(torch.nn.Module):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
child = L_x_
|
||||
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
|
||||
contiguous = call.contiguous(); call = None
|
||||
return (contiguous,)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_):
|
||||
sin = l_x_.sin(); l_x_ = None
|
||||
add = sin + 3.14; sin = None
|
||||
sum_1 = add.sum(); add = None
|
||||
return sum_1
|
||||
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = diff_args.sin()
|
||||
add = sin + 3.14; sin = None
|
||||
output = add.sum(); add = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
||||
grad_input = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (grad,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
@ -2698,23 +2855,36 @@ class GraphModule(torch.nn.Module):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
child = L_x_
|
||||
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, 0, True); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
|
||||
getitem = call[0]
|
||||
getitem_1 = call[1]; call = None
|
||||
contiguous = getitem.contiguous(); getitem = None
|
||||
return (contiguous, getitem_1)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_):
|
||||
sin = l_x_.sin()
|
||||
add = sin + 3.14; sin = None
|
||||
sum_1 = add.sum(); add = None
|
||||
cos = l_x_.cos(); l_x_ = None
|
||||
return (sum_1, cos)
|
||||
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = diff_args.sin()
|
||||
add = sin + 3.14; sin = None
|
||||
output = add.sum(); add = None
|
||||
aux = diff_args.cos()
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
||||
grad_input = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
|
||||
|
||||
aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (grad, aux_2)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
@ -2742,24 +2912,38 @@ class GraphModule(torch.nn.Module):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
child = L_x_
|
||||
child_1 = L_y_
|
||||
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, 0, True); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_, l_y_); grad_proxy = l_x_ = l_y_ = None
|
||||
getitem = call[0]
|
||||
getitem_1 = call[1]; call = None
|
||||
contiguous = getitem.contiguous(); getitem = None
|
||||
return (contiguous, getitem_1)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_, l_y_):
|
||||
sin = l_x_.sin()
|
||||
add = sin + l_y_; sin = l_y_ = None
|
||||
sum_1 = add.sum(); add = None
|
||||
cos = l_x_.cos(); l_x_ = None
|
||||
return (sum_1, cos)
|
||||
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
_wrap_for_grad_1 = torch._C._functorch._wrap_for_grad(child_1, 1); child_1 = None
|
||||
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = diff_args.sin()
|
||||
add = sin + _wrap_for_grad_1; sin = _wrap_for_grad_1 = None
|
||||
output = add.sum(); add = None
|
||||
aux = diff_args.cos()
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
||||
grad_input = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
|
||||
|
||||
aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (grad, aux_2)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
@ -2798,27 +2982,45 @@ class GraphModule(torch.nn.Module):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
child = L_x_
|
||||
child_1 = L_y_
|
||||
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, (0, 1), True); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_, l_y_); grad_proxy = l_x_ = l_y_ = None
|
||||
getitem = call[0]
|
||||
getitem_1 = getitem[0]
|
||||
getitem_2 = getitem[1]; getitem = None
|
||||
getitem_3 = call[1]; call = None
|
||||
contiguous = getitem_1.contiguous(); getitem_1 = None
|
||||
contiguous_1 = getitem_2.contiguous(); getitem_2 = None
|
||||
return (contiguous, contiguous_1, getitem_3)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_, l_y_):
|
||||
sin = l_x_.sin()
|
||||
add = sin + l_y_; sin = l_y_ = None
|
||||
sum_1 = add.sum(); add = None
|
||||
cos = l_x_.cos(); l_x_ = None
|
||||
return (sum_1, cos)
|
||||
child_4 = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
child_5 = torch._C._functorch._wrap_for_grad(child_1, 1); child_1 = None
|
||||
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(child_4)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad_1 = torch._functorch.eager_transforms._tensor_requires_grad(child_5)
|
||||
|
||||
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = child_4.sin()
|
||||
add = sin + child_5; sin = None
|
||||
output = add.sum(); add = None
|
||||
aux = child_4.cos()
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child_4, child_5], create_graph = True); child_4 = child_5 = None
|
||||
child_6 = _autograd_grad[0]
|
||||
child_7 = _autograd_grad[1]; _autograd_grad = None
|
||||
|
||||
_unwrap_for_grad = torch._C._functorch._unwrap_for_grad(child_6, 1); child_6 = None
|
||||
_unwrap_for_grad_1 = torch._C._functorch._unwrap_for_grad(child_7, 1); child_7 = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
|
||||
|
||||
aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (_unwrap_for_grad, _unwrap_for_grad_1, aux_2)
|
||||
""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
|
|
@ -2826,27 +3028,45 @@ class GraphModule(torch.nn.Module):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
child = L_x_
|
||||
child_1 = L_y_
|
||||
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, (0, 1), True); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_, l_y_); grad_proxy = l_x_ = l_y_ = None
|
||||
getitem = call[0]
|
||||
getitem_1 = getitem[0]
|
||||
getitem_2 = getitem[1]; getitem = None
|
||||
getitem_3 = call[1]; call = None
|
||||
contiguous = getitem_1.contiguous(); getitem_1 = None
|
||||
contiguous_1 = getitem_2.contiguous(); getitem_2 = None
|
||||
return (contiguous, contiguous_1, getitem_3)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_, l_y_):
|
||||
sin = l_x_.sin()
|
||||
add = sin + l_y_; sin = l_y_ = None
|
||||
sum_1 = add.sum(); add = None
|
||||
cos = l_x_.cos(); l_x_ = None
|
||||
return (sum_1, cos)
|
||||
child_4 = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
child_5 = torch._C._functorch._wrap_for_grad(child_1, 1); child_1 = None
|
||||
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(child_4)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad_1 = torch._functorch.eager_transforms._tensor_requires_grad(child_5)
|
||||
|
||||
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = child_4.sin()
|
||||
add = sin + child_5; sin = None
|
||||
output = add.sum(); add = None
|
||||
aux = child_4.cos()
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [child_4, child_5], create_graph = True); child_4 = child_5 = None
|
||||
child_6 = _autograd_grad[0]
|
||||
child_7 = _autograd_grad[1]; _autograd_grad = None
|
||||
|
||||
_unwrap_for_grad = torch._C._functorch._unwrap_for_grad(child_6, 1); child_6 = None
|
||||
_unwrap_for_grad_1 = torch._C._functorch._unwrap_for_grad(child_7, 1); child_7 = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
|
||||
|
||||
aux_2 = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (_unwrap_for_grad, _unwrap_for_grad_1, aux_2)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
@ -2872,27 +3092,52 @@ class GraphModule(torch.nn.Module):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
child = L_x_
|
||||
|
||||
grad_body_1 = self.grad_body_1
|
||||
grad_proxy = torch.func.grad(grad_body_1, 0, False); grad_body_1 = None
|
||||
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
|
||||
contiguous = call.contiguous(); call = None
|
||||
return (contiguous,)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_):
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_); grad_proxy = l_x_ = None
|
||||
contiguous = call.contiguous(); call = None
|
||||
return contiguous
|
||||
child_1 = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_):
|
||||
sin = l_x_.sin(); l_x_ = None
|
||||
sum_1 = sin.sum(); sin = None
|
||||
return sum_1
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(child_1)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting_1 = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
diff_args_1 = torch._C._functorch._wrap_for_grad(child_1, 2)
|
||||
|
||||
set_inplace_requires_grad_allowed_2 = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad_1 = torch._functorch.eager_transforms._tensor_requires_grad(diff_args_1)
|
||||
|
||||
set_inplace_requires_grad_allowed_3 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = diff_args_1.sin()
|
||||
output = sin.sum(); sin = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True); diff_args_1 = None
|
||||
grad_input = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
output_2 = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 2); output = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
|
||||
_autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((output_2,), [child_1], create_graph = True); child_1 = None
|
||||
grad_input_2 = _autograd_grad_1[0]; _autograd_grad_1 = None
|
||||
|
||||
grad_1 = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None
|
||||
|
||||
__1 = torch._C._functorch._unwrap_for_grad(output_2, 1); output_2 = None
|
||||
|
||||
_grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (grad_1,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
@ -2929,14 +3174,7 @@ class GraphModule(torch.nn.Module):
|
|||
x = torch.randn(3, 3, 3)
|
||||
actual = wrapper_fn(x)
|
||||
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x)
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
assert_dict_matches_regex(
|
||||
self,
|
||||
dict(counters["graph_break"]),
|
||||
{
|
||||
r".*HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)": 2
|
||||
},
|
||||
)
|
||||
self.assertEqual(len(counters["graph_break"]), 0)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
@config.patch(capture_func_transforms=True)
|
||||
|
|
@ -2956,12 +3194,7 @@ class GraphModule(torch.nn.Module):
|
|||
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
|
||||
(x1, x2)
|
||||
)
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
assert_dict_matches_regex(
|
||||
self,
|
||||
dict(counters["graph_break"]),
|
||||
{".*torch.func.grad with body that accepts non-Tensors as input": 2},
|
||||
)
|
||||
self.assertEqual(len(counters["graph_break"]), 0)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
@config.patch(capture_func_transforms=True)
|
||||
|
|
@ -2988,20 +3221,33 @@ class GraphModule(torch.nn.Module):
|
|||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
child = L_x_
|
||||
|
||||
grad_body_0 = self.grad_body_0
|
||||
grad_proxy = torch.func.grad(grad_body_0, 0, False); grad_body_0 = None
|
||||
call = grad_proxy.__call__(l_x_, 3.0); grad_proxy = l_x_ = None
|
||||
contiguous = call.contiguous(); call = None
|
||||
return (contiguous,)
|
||||
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
|
||||
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
|
||||
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, l_x_, const):
|
||||
sin = l_x_.sin(); l_x_ = None
|
||||
sum_1 = sin.sum(); sin = None
|
||||
add = sum_1 + 3.0; sum_1 = None
|
||||
return add
|
||||
diff_args = torch._C._functorch._wrap_for_grad(child, 1); child = None
|
||||
|
||||
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
|
||||
|
||||
_tensor_requires_grad = torch._functorch.eager_transforms._tensor_requires_grad(diff_args)
|
||||
|
||||
set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False)
|
||||
|
||||
sin = diff_args.sin()
|
||||
sum_1 = sin.sum(); sin = None
|
||||
output = sum_1 + 3.0; sum_1 = None
|
||||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None
|
||||
grad_input = _autograd_grad[0]; _autograd_grad = None
|
||||
|
||||
grad = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None
|
||||
|
||||
_ = torch._C._functorch._unwrap_for_grad(output, 1); output = None
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
|
||||
return (grad,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
@ -3045,11 +3291,7 @@ class GraphModule(torch.nn.Module):
|
|||
y = torch.randn(3, 3)
|
||||
actual = wrapper_fn(x, y)
|
||||
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(x, y)
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
self.assertEqual(
|
||||
dict(counters["graph_break"]),
|
||||
{"torch.func.grad: kwargs arguments are currently unsupported.": 2},
|
||||
)
|
||||
self.assertEqual(len(counters["graph_break"]), 0)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
@config.patch(capture_func_transforms=True)
|
||||
|
|
@ -3064,6 +3306,7 @@ class GraphModule(torch.nn.Module):
|
|||
# should not recompile on second call. See Pytorch issue #118493
|
||||
y = torch.vmap(fn)(x)
|
||||
|
||||
@xfailIfTorchDynamo
|
||||
@config.patch(capture_func_transforms=True)
|
||||
@config.patch(error_on_recompile=True)
|
||||
def test_vmap_recompile_different_config(self):
|
||||
|
|
@ -3100,6 +3343,18 @@ class GraphModule(torch.nn.Module):
|
|||
with self.assertRaises(torch._dynamo.exc.RecompileError):
|
||||
torch.vmap(fn, randomness="different")(x)
|
||||
|
||||
@config.patch(capture_func_transforms=True)
|
||||
@config.patch(error_on_recompile=True)
|
||||
def test_grad_recompile(self):
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
return torch.func.grad(torch.sin)(x)
|
||||
|
||||
x = torch.randn([])
|
||||
torch.func.grad(fn)(x)
|
||||
# should not recompile on second call
|
||||
torch.func.grad(fn)(x)
|
||||
|
||||
@config.patch(capture_func_transforms=True)
|
||||
def test_vmap_get_wrapped(self):
|
||||
counters.clear()
|
||||
|
|
|
|||
|
|
@ -1670,7 +1670,7 @@ class TestJac(TestCase):
|
|||
|
||||
self.assertEqual(jacrev(func)(x), torch.autograd.functional.jacobian(func, x))
|
||||
|
||||
@FIXME_jacrev_only
|
||||
@jacrev_and_jacfwd
|
||||
def test_diff_numel(self, device, jacapi):
|
||||
x = torch.randn(2, 4, device=device)
|
||||
|
||||
|
|
@ -1687,14 +1687,14 @@ class TestJac(TestCase):
|
|||
expected[2, 0, 0, 3] = 1
|
||||
self.assertEqual(y, expected)
|
||||
|
||||
@FIXME_jacrev_only
|
||||
@jacrev_and_jacfwd
|
||||
def test_vmap_on_jac_simple(self, device, jacapi):
|
||||
x = torch.randn(2, 3, device=device)
|
||||
y = vmap(jacapi(torch.sin))(x)
|
||||
expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
@FIXME_jacrev_only
|
||||
@jacrev_and_jacfwd
|
||||
def test_nested_jac_simple(self, device, jacapi):
|
||||
def foo(x):
|
||||
return x.sin().sum()
|
||||
|
|
@ -1755,7 +1755,7 @@ class TestJac(TestCase):
|
|||
self.assertTrue(isinstance(z[0], tuple))
|
||||
self.assertEqual(z, ((expected_out0_x,), (expected_out1_x,)))
|
||||
|
||||
@FIXME_jacrev_only
|
||||
@jacrev_and_jacfwd
|
||||
def test_multiple_outputs_pytree(self, device, jacapi):
|
||||
def f(x, y):
|
||||
return {'left': 2 * x + 3 * y, 'right': 4 * x + 5 * y}
|
||||
|
|
@ -1816,7 +1816,7 @@ class TestJac(TestCase):
|
|||
self.assertEqual(result.dim(), 2)
|
||||
self.assertEqual(result, x.new_ones(1, 1))
|
||||
|
||||
@FIXME_jacrev_only
|
||||
@jacrev_and_jacfwd
|
||||
def test_aux_tensor(self, device, jacapi):
|
||||
def f(x):
|
||||
y = x.clone()
|
||||
|
|
@ -1912,7 +1912,7 @@ class TestJac(TestCase):
|
|||
)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@FIXME_jacrev_only
|
||||
@jacrev_and_jacfwd
|
||||
def test_multiple_inputs_outputs_pytree_multidim(self, device, jacapi):
|
||||
def f(dct):
|
||||
a = dct['a']
|
||||
|
|
@ -4806,6 +4806,7 @@ class TestCompileTransforms(TestCase):
|
|||
|
||||
actual = wrapper_fn(x, y)
|
||||
expected = torch.compile(wrapper_fn, backend='eager', fullgraph=True)(x, y)
|
||||
fn = torch.compile(wrapper_fn, backend='eager', fullgraph=True)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def wrapper_fn(x, y):
|
||||
|
|
|
|||
|
|
@@ -24,6 +24,10 @@ def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
def get_single_level_autograd_function_allowed() -> bool: ...
def _unwrap_functional_tensor(tensor: Tensor, reapply_views: bool) -> Tensor: ...
def _wrap_functional_tensor(tensor: Tensor, level: int) -> Tensor: ...
def _vmap_increment_nesting(batch_size: int, randomness: str) -> int: ...
def _vmap_decrement_nesting() -> int: ...
def _grad_increment_nesting() -> int: ...
def _grad_decrement_nesting() -> int: ...

# Defined in aten/src/ATen/functorch/Interpreter.h
class TransformType(Enum):
|
|
|
|||
|
|
@@ -2109,7 +2109,7 @@ class InstructionTranslator(InstructionTranslatorBase):
            speculation_log=speculation_log,
        )

        self._throw_if_in_vmap()
        self._throw_if_in_functorch()

        # as soon as we create the tracing context we should keep it active, so any calls
        # into dynamo apis can rely on finding it

@@ -2147,20 +2147,21 @@ class InstructionTranslator(InstructionTranslatorBase):
            if name in f_locals:
                self._freevars_ids[name] = id(f_locals[name])

    def _throw_if_in_vmap(self):
    def _throw_if_in_functorch(self):
        # Fallback to eager in case of a graph break inside vmap
        eager = torch._dynamo.lookup_backend("eager")
        compiler_fn = inspect.getattr_static(
            self.output.compiler_fn, "compiler_fn", self.output.compiler_fn
        )
        ci = torch._C._functorch.peek_interpreter_stack()
        if (
            ci is not None
            and ci.key() == torch._C._functorch.TransformType.Vmap
            and compiler_fn is not eager
        ):
            # if it reaches here, it means Dynamo failed to inline vmap
            msg = "torch.vmap(fn) requires the function to be inlined by dynamo"
        forbidden_keys = (
            torch._C._functorch.TransformType.Vmap,
            torch._C._functorch.TransformType.Grad,
        )
        if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager:
            # if it reaches here, it means Dynamo failed to inline a functorch function
            name = ci.key().name.lower()
            msg = f"torch.func.{name}(fn) requires the function to be inlined by dynamo"
            unimplemented(msg)

    def get_example_value(self, source: Source):
|
|
|||
|
|
@@ -49,7 +49,7 @@ from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper

from .variables import (
    BuiltinVariable,
    FunctorchVmapHigherOrderVariable,
    FunctorchHigherOrderVariable,
    NestedUserFunctionVariable,
    SkipFunctionVariable,
    TorchInGraphFunctionVariable,

@@ -155,7 +155,7 @@ manual_torch_name_rule_map = {
    "torch.resize_as_": SkipFunctionVariable,
    "torch.resize_as_sparse_": SkipFunctionVariable,
    "torch.get_default_device": TorchInGraphFunctionVariable,
    # functorch
    # functorch/vmap
    "torch._functorch.vmap._check_int_or_none": UserFunctionVariable,
    "torch._functorch.vmap._check_out_dims_is_int_or_int_pytree": UserFunctionVariable,
    "torch._functorch.vmap._check_randomness_arg": UserFunctionVariable,

@@ -178,8 +178,20 @@ manual_torch_name_rule_map = {
    "torch._functorch.vmap.restore_vmap": UserFunctionVariable,
    "torch._functorch.apis.vmap": UserFunctionVariable,
    "torch._functorch.vmap.unwrap_batched": UserFunctionVariable,
    "torch._functorch.vmap.vmap_impl": FunctorchVmapHigherOrderVariable,
    "torch._functorch.vmap.vmap_impl": FunctorchHigherOrderVariable,
    "torch._functorch.vmap.wrap_batched": UserFunctionVariable,
    # functorch/grad
    "torch._functorch.eager_transforms.grad_impl": FunctorchHigherOrderVariable,
    "torch._functorch.apis.grad_and_value": UserFunctionVariable,
    "torch._functorch.eager_transforms._as_tuple": UserFunctionVariable,
    "torch._functorch.eager_transforms._check_unique_non_empty": UserFunctionVariable,
    "torch._functorch.eager_transforms._create_differentiable": UserFunctionVariable,
    "torch._functorch.eager_transforms._slice_argnums": UserFunctionVariable,
    "torch._functorch.eager_transforms._undo_create_differentiable": UserFunctionVariable,
    "torch._functorch.eager_transforms._validate_and_wrap_argnum": UserFunctionVariable,
    "torch._functorch.eager_transforms._validate_and_wrap_argnums": UserFunctionVariable,
    "torch._functorch.eager_transforms._wrap_all_tensors": UserFunctionVariable,
    "torch._functorch.eager_transforms._wrap_tensor_for_grad": UserFunctionVariable,
    "torch._constrain_as_size": UserFunctionVariable,
    "torch._constrain_as_value": UserFunctionVariable,
    "torch._tensor._convert": UserFunctionVariable,

@@ -187,6 +199,8 @@ manual_torch_name_rule_map = {
    "torch.backends.mha.get_fastpath_enabled": UserFunctionVariable,
    "torch._C._functorch._add_batch_dim": TorchInGraphFunctionVariable,
    "torch._C._functorch._remove_batch_dim": TorchInGraphFunctionVariable,
    "torch._C._functorch._wrap_for_grad": TorchInGraphFunctionVariable,
    "torch._C._functorch._unwrap_for_grad": TorchInGraphFunctionVariable,
    "torch._C._functorch.is_batchedtensor": TorchInGraphFunctionVariable,
    "torch._dynamo.mark_static": UserFunctionVariable,
    "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable,
|
@ -2132,36 +2146,25 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
|
|||
"torch._functorch.deprecated.vjp",
|
||||
"torch._functorch.deprecated.warn_deprecated",
|
||||
"torch._functorch.eager_transforms._any_differentiable",
|
||||
"torch._functorch.eager_transforms._as_tuple",
|
||||
"torch._functorch.eager_transforms._autograd_grad",
|
||||
"torch._functorch.eager_transforms._check_unique_non_empty",
|
||||
"torch._functorch.eager_transforms._chunked_standard_basis_for_",
|
||||
"torch._functorch.eager_transforms._construct_standard_basis_for",
|
||||
"torch._functorch.eager_transforms._create_differentiable",
|
||||
"torch._functorch.eager_transforms._tensor_requires_grad",
|
||||
"torch._functorch.eager_transforms._is_differentiable",
|
||||
"torch._functorch.eager_transforms._jvp_with_argnums",
|
||||
"torch._functorch.eager_transforms._maybe_unwrap_functional_tensor",
|
||||
"torch._functorch.eager_transforms._maybe_wrap_functional_tensor",
|
||||
"torch._functorch.eager_transforms._replace_args",
|
||||
"torch._functorch.eager_transforms._safe_zero_index",
|
||||
"torch._functorch.eager_transforms._slice_argnums",
|
||||
"torch._functorch.eager_transforms._undo_create_differentiable",
|
||||
"torch._functorch.eager_transforms._unwrap_all_tensors_from_functional",
|
||||
"torch._functorch.eager_transforms._validate_and_wrap_argnum",
|
||||
"torch._functorch.eager_transforms._validate_and_wrap_argnums",
|
||||
"torch._functorch.eager_transforms._vjp_with_argnums",
|
||||
"torch._functorch.eager_transforms._wrap_all_tensors_to_functional",
|
||||
"torch._functorch.eager_transforms._wrap_all_tensors",
|
||||
"torch._functorch.eager_transforms._wrap_tensor_for_grad",
|
||||
"torch._functorch.eager_transforms.assert_flat_tuple_of_tensors",
|
||||
"torch._functorch.eager_transforms.assert_non_empty_list_of_tensors",
|
||||
"torch._functorch.eager_transforms.assert_non_empty_tensor_output",
|
||||
"torch._functorch.eager_transforms.assert_output_is_tensor_or_tensors",
|
||||
"torch._functorch.eager_transforms.enable_inplace_requires_grad",
|
||||
"torch._functorch.eager_transforms.error_if_complex",
|
||||
"torch._functorch.eager_transforms.functionalize",
|
||||
"torch._functorch.eager_transforms.grad_and_value",
|
||||
"torch._functorch.eager_transforms.grad_impl",
|
||||
"torch._functorch.eager_transforms.hessian",
|
||||
"torch._functorch.eager_transforms.jacfwd",
|
||||
"torch._functorch.eager_transforms.jacrev",
|
||||
|
|
@@ -3171,6 +3174,7 @@ MOD_INLINELIST = {
    "torch._dynamo.comptime",
    "torch._dynamo.polyfill",
    "torch._functorch.vmap",
    "torch._functorch.eager_transforms",
    "torch._inductor.test_operators",
    "torch.amp.autocast_mode",
    "torch.ao.nn",

@@ -3355,7 +3359,7 @@ def check_verbose(obj, is_inlined_call=False):
        rule = torch._dynamo.trace_rules.lookup_inner(
            fi.py_obj, fi.name, fi.filename, is_inlined_call
        )
        if rule in [UserFunctionVariable, FunctorchVmapHigherOrderVariable]:
        if rule in [UserFunctionVariable, FunctorchHigherOrderVariable]:
            return SkipResult(
                False,
                "inlined according trace_rules.lookup",
|
|
|||
|
|
@@ -813,7 +813,9 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor):

@contextmanager
def preserve_rng_state():
    with torch.utils._python_dispatch._disable_current_modes():
    disable_functorch = torch._C._DisableFuncTorch
    disable_current_modes = torch.utils._python_dispatch._disable_current_modes
    with disable_current_modes(), disable_functorch():
        rng_state = torch.clone(torch.random.get_rng_state())
        skip_frame_if_in_functorch_mode(rng_state)
        if torch.cuda.is_available():
|
|
|||
|
|
@ -7,6 +7,8 @@ from .ctx_manager import (
|
|||
ContextWrappingVariable,
|
||||
DeterministicAlgorithmsVariable,
|
||||
DisabledSavedTensorsHooksVariable,
|
||||
GradIncrementNestingCtxManagerVariable,
|
||||
GradInplaceRequiresGradCtxManagerVariable,
|
||||
GradModeVariable,
|
||||
InferenceModeVariable,
|
||||
StreamContextVariable,
|
||||
|
|
@ -29,7 +31,7 @@ from .functions import (
|
|||
UserMethodVariable,
|
||||
)
|
||||
from .higher_order_ops import (
|
||||
FunctorchVmapHigherOrderVariable,
|
||||
FunctorchHigherOrderVariable,
|
||||
TorchHigherOrderOperatorVariable,
|
||||
)
|
||||
from .iter import (
|
||||
|
|
|
|||
|
|
@ -1510,6 +1510,8 @@ def wrap_fx_proxy_cls(
|
|||
torch._C._functorch._vmap_increment_nesting,
|
||||
torch._C._functorch._vmap_decrement_nesting,
|
||||
torch._functorch.vmap._validate_and_get_batch_size,
|
||||
torch._C._functorch._grad_increment_nesting,
|
||||
torch._C._functorch._grad_decrement_nesting,
|
||||
# some mac builds are missing torch.distributed.get_rank()
|
||||
getattr(torch.distributed, "get_rank", _missing),
|
||||
getattr(torch.distributed, "get_world_size", _missing),
|
||||
|
|
|
|||
|
|
@ -149,6 +149,85 @@ class GenericContextWrappingVariable(ContextWrappingVariable):
|
|||
return x
|
||||
|
||||
|
||||
class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable):
|
||||
"""represents torch grad requries grad"""
|
||||
|
||||
@staticmethod
|
||||
def create(tx, target_values, **kwargs):
|
||||
return GradInplaceRequiresGradCtxManagerVariable(
|
||||
target_values=target_values,
|
||||
initial_values=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def enter(self, tx):
|
||||
[enabled] = self.target_values
|
||||
self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed()
|
||||
torch._C._functorch.set_inplace_requires_grad_allowed(enabled)
|
||||
self.set_cleanup_hook(
|
||||
tx,
|
||||
lambda: torch._C._functorch.set_inplace_requires_grad_allowed(
|
||||
self.prev_state
|
||||
),
|
||||
)
|
||||
self.state.proxy = tx.output.create_node(
|
||||
"call_function",
|
||||
torch._C._functorch.set_inplace_requires_grad_allowed,
|
||||
(enabled,),
|
||||
{},
|
||||
)
|
||||
return variables.ConstantVariable.create(None)
|
||||
|
||||
def exit(self, tx, *args):
|
||||
self.state.cleanup()
|
||||
tx.output.create_node(
|
||||
"call_function",
|
||||
torch._C._functorch.set_inplace_requires_grad_allowed,
|
||||
(self.prev_state,),
|
||||
{},
|
||||
)
|
||||
return variables.ConstantVariable.create(None)
|
||||
|
||||
|
||||
class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
|
||||
"""represents torch.func.grad increment/decrement nesting"""
|
||||
|
||||
# A guard is needed as the grad level is baked into the torch FX graph
|
||||
# This is fine if grad is only called from within the function
|
||||
# being compiled. But the FX graph may be invalid in the case of a grad
|
||||
# call from eager that calls the compiled function, as the grad levels
|
||||
# may be different.
|
||||
_guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH)
|
||||
|
||||
@staticmethod
|
||||
def create(tx, **kwargs):
|
||||
var = GradIncrementNestingCtxManagerVariable(
|
||||
target_values=None,
|
||||
initial_values=None,
|
||||
**kwargs,
|
||||
)
|
||||
return var
|
||||
|
||||
def enter(self, tx):
|
||||
install_guard(self._guards_singleton)
|
||||
grad_level = torch._C._functorch._grad_increment_nesting()
|
||||
self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting())
|
||||
self.state.proxy = tx.output.create_node(
|
||||
"call_function",
|
||||
torch._C._functorch._grad_increment_nesting,
|
||||
(),
|
||||
{},
|
||||
)
|
||||
return variables.ConstantVariable.create(grad_level)
|
||||
|
||||
def exit(self, tx, *args):
|
||||
self.state.cleanup()
|
||||
tx.output.create_node(
|
||||
"call_function", torch._C._functorch._grad_decrement_nesting, (), {}
|
||||
)
|
||||
return variables.ConstantVariable.create(None)
|
||||
|
||||
|
||||
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
|
||||
"""represents torch VMap increment/decrement nesting"""
|
||||
|
||||
|
|
|
|||
|
|
@@ -1180,13 +1180,16 @@ class FunctorchGradHigherOrderVariable(TorchHigherOrderOperatorVariable):
        return TupleVariable([TupleVariable(items), aux])


class FunctorchVmapHigherOrderVariable(UserFunctionVariable):
class FunctorchHigherOrderVariable(UserFunctionVariable):
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if not torch._dynamo.config.capture_func_transforms:
            name = self.get_name()
            assert name in ("grad_impl", "vmap_impl")
            fn = name.split("_")[0]
            unimplemented(
                "torch.func.vmap capture is disabled, "
                f"torch.func.{fn} capture is disabled, "
                "it can be turned on by setting "
                "`torch._dynamo.config.capture_func_transforms=True`"
            )
|
|
|
|||
|
|
@ -43,7 +43,6 @@ from .ctx_manager import (
|
|||
TorchFunctionDisableVariable,
|
||||
)
|
||||
from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
|
||||
from .higher_order_ops import TorchHigherOrderOperatorVariable
|
||||
from .lists import ListVariable, TupleVariable
|
||||
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
|
||||
|
||||
|
|
@ -55,6 +54,8 @@ supported_ctx_manager_classes = {
|
|||
torch.autograd.profiler.record_function,
|
||||
torch._C.DisableTorchFunctionSubclass,
|
||||
torch._functorch.vmap.vmap_increment_nesting,
|
||||
torch._functorch.eager_transforms.grad_increment_nesting,
|
||||
torch._functorch.eager_transforms.enable_inplace_requires_grad,
|
||||
torch.amp.autocast_mode.autocast,
|
||||
torch.autograd.grad_mode.enable_grad,
|
||||
torch.autograd.grad_mode.inference_mode,
|
||||
|
|
@ -176,6 +177,8 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
|
|||
) -> "VariableTracker":
|
||||
from . import (
|
||||
DisabledSavedTensorsHooksVariable,
|
||||
GradIncrementNestingCtxManagerVariable,
|
||||
GradInplaceRequiresGradCtxManagerVariable,
|
||||
GradModeVariable,
|
||||
InferenceModeVariable,
|
||||
StreamVariable,
|
||||
|
|
@ -241,6 +244,17 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
|
|||
tx,
|
||||
[guard_if_dyn(x) for x in args],
|
||||
)
|
||||
elif self.value is torch._functorch.eager_transforms.grad_increment_nesting:
|
||||
assert len(args) == 0
|
||||
return GradIncrementNestingCtxManagerVariable.create(tx)
|
||||
elif (
|
||||
self.value is torch._functorch.eager_transforms.enable_inplace_requires_grad
|
||||
):
|
||||
assert len(args) == 1
|
||||
return GradInplaceRequiresGradCtxManagerVariable.create(
|
||||
tx,
|
||||
[guard_if_dyn(x) for x in args],
|
||||
)
|
||||
elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
|
||||
assert len(args) == 1
|
||||
return DisabledSavedTensorsHooksVariable.create(
|
||||
|
|
@ -293,11 +307,6 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
):
|
||||
tx.mark_inconsistent_side_effects()
|
||||
return ConstantVariable.create(tracing_state_functions[self.value])
|
||||
elif self.value in (torch._functorch.eager_transforms.grad_impl,):
|
||||
return TorchHigherOrderOperatorVariable.make(
|
||||
self.value,
|
||||
source=self.source,
|
||||
).call_function(tx, args, kwargs)
|
||||
elif self.value is torch.overrides.get_default_nowrap_functions.__wrapped__:
|
||||
# [Note: __torch_function__] we return empty here because we restrict
|
||||
# the set of functions that we trace __torch_function__ on to
|
||||
|
|
|
|||
|
|
@@ -362,3 +362,40 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
    def wrapper(*args, **kwargs):
        return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
    return wrapper


@exposed_in("torch.func")
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    """
    Returns a function to compute a tuple of the gradient and primal, or
    forward, computation.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a single-element Tensor. If specified ``has_aux``
            equals ``True``, function can return a tuple of single-element
            Tensor and other auxiliary objects: ``(output, aux)``.
        argnums (int or Tuple[int]): Specifies arguments to compute gradients
            with respect to. ``argnums`` can be single integer or tuple of
            integers. Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a tensor and
            other auxiliary objects: ``(output, aux)``. Default: False.

    Returns:
        Function to compute a tuple of gradients with respect to its inputs
        and the forward computation. By default, the output of the function is
        a tuple of the gradient tensor(s) with respect to the first argument
        and the primal computation. If specified ``has_aux`` equals
        ``True``, tuple of gradients and tuple of the forward computation with
        output auxiliary objects is returned. If ``argnums`` is a tuple of
        integers, a tuple of a tuple of the output gradients with respect to
        each ``argnums`` value and the forward computation is returned.

    See :func:`grad` for examples
    """
    from torch._functorch import eager_transforms

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return eager_transforms.grad_and_value_impl(func, argnums, has_aux, args, kwargs)
    return wrapper
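For reference (not part of the diff), a small usage sketch of the public `torch.func.grad_and_value` API whose wrapper is moved here:

    import torch
    from torch.func import grad_and_value

    def loss(x):
        # must return a scalar tensor
        return (x ** 2).sum()

    x = torch.randn(3)
    g, value = grad_and_value(loss)(x)
    # g is the gradient of loss at x (2 * x); value is the scalar loss(x)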
|
||||
|
|
|
|||
|
|
@@ -67,7 +67,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla

def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    warn_deprecated('grad_and_value')
    return _impl.grad_and_value(func, argnums, has_aux)
    return apis.grad_and_value(func, argnums, has_aux)

def vjp(func: Callable, *primals, has_aux: bool = False):
    warn_deprecated('vjp')

@@ -110,7 +110,7 @@ def combine_state_for_ensemble(models):

setup_docs(vmap, apis.vmap, 'torch.vmap')
setup_docs(grad, apis.grad)
setup_docs(grad_and_value)
setup_docs(grad_and_value, apis.grad_and_value)
setup_docs(vjp)
setup_docs(jvp)
setup_docs(jacrev)
|
|
|||
|
|
@@ -41,7 +41,7 @@ from torch._C._functorch import (
    _assert_wrapped_functional,
    _propagate_functional_input_mutation,
    set_inplace_requires_grad_allowed,
    get_inplace_requires_grad_allowed
    get_inplace_requires_grad_allowed,
)
from torch._functorch.utils import exposed_in, argnums_t
@@ -51,7 +51,7 @@ def lazy_dynamo_disable(func):
        return torch._dynamo.disable(func)

@contextlib.contextmanager
def enable_inplace_requires_grad(enabled=True):
def enable_inplace_requires_grad(enabled):
    prev_state = get_inplace_requires_grad_allowed()
    set_inplace_requires_grad_allowed(enabled)
    try:

@@ -60,11 +60,16 @@ def enable_inplace_requires_grad(enabled=True):
        set_inplace_requires_grad_allowed(prev_state)


def _tensor_requires_grad(x):
    # avoid graph-break on x.requires_grad_()
    # https://github.com/pytorch/pytorch/pull/110053
    return x.requires_grad_()

def _create_differentiable(inps, level=None):
    def create_differentiable(x):
        if isinstance(x, torch.Tensor):
            with enable_inplace_requires_grad():
                return x.requires_grad_()
            with enable_inplace_requires_grad(True):
                return _tensor_requires_grad(x)
        raise ValueError(f'Thing passed to transform API must be Tensor, '
                         f'got {type(x)}')
    return tree_map(create_differentiable, inps)
|
@@ -277,6 +282,15 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
    return _vjp_with_argnums(func, *primals, has_aux=has_aux)


@contextlib.contextmanager
def grad_increment_nesting():
    try:
        grad_level = _grad_increment_nesting()
        yield grad_level
    finally:
        _grad_decrement_nesting()


@doesnt_support_saved_tensors_hooks
def _vjp_with_argnums(func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False):
    # This is the same function as vjp but also accepts an argnums argument
|
|
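The new grad_increment_nesting context manager pairs the increment/decrement calls so the functorch nesting level is restored even when the wrapped computation raises. A minimal sketch of the same try/finally pattern with a hypothetical push/pop pair (illustration only, not part of this commit):

import contextlib

_stack = []

@contextlib.contextmanager
def nesting(name):
    # Push on entry, always pop on exit, mirroring the pairing of
    # _grad_increment_nesting() and _grad_decrement_nesting() above.
    _stack.append(name)
    try:
        yield len(_stack)  # the current nesting level
    finally:
        _stack.pop()

with nesting("grad") as level:
    assert level == 1
assert not _stack  # popped even if the body had raised
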
@@ -1213,88 +1227,54 @@ def hessian(func, argnums=0):
    return jacfwd(jacrev(func, argnums), argnums)


@exposed_in("torch.func")
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    """
    Returns a function to compute a tuple of the gradient and primal, or
    forward, computation.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a single-element Tensor. If specified ``has_aux``
            equals ``True``, function can return a tuple of single-element
            Tensor and other auxiliary objects: ``(output, aux)``.
        argnums (int or Tuple[int]): Specifies arguments to compute gradients
            with respect to. ``argnums`` can be single integer or tuple of
            integers. Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a tensor and
            other auxiliary objects: ``(output, aux)``. Default: False.

    Returns:
        Function to compute a tuple of gradients with respect to its inputs
        and the forward computation. By default, the output of the function is
        a tuple of the gradient tensor(s) with respect to the first argument
        and the primal computation. If specified ``has_aux`` equals
        ``True``, tuple of gradients and tuple of the forward computation with
        output auxiliary objects is returned. If ``argnums`` is a tuple of
        integers, a tuple of a tuple of the output gradients with respect to
        each ``argnums`` value and the forward computation is returned.

    See :func:`grad` for examples
    """
    @doesnt_support_saved_tensors_hooks
    @wraps(func)
    def wrapper(*args, **kwargs):
        level = _grad_increment_nesting()
        try:
            output, aux, grad_input = None, None, None
            # See NOTE [grad and vjp interaction with no_grad]
            with torch.enable_grad():
                args = _wrap_all_tensors(args, level)
                kwargs = _wrap_all_tensors(kwargs, level)
                diff_args = _slice_argnums(args, argnums, as_tuple=False)
                tree_map_(partial(_create_differentiable, level=level), diff_args)

                output = func(*args, **kwargs)
                if has_aux:
                    if not (isinstance(output, tuple) and len(output) == 2):
                        raise RuntimeError(
                            "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
                            "if has_aux is True"
                        )
                    output, aux = output

                if not isinstance(output, torch.Tensor):
                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                       f'to return a Tensor, got {type(output)}')
                if output.dim() != 0:
                    raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                       'to return a scalar Tensor, got tensor with '
                                       f'{output.dim()} dims. Maybe you wanted to '
                                       'use the vjp or jacrev APIs instead?')

                flat_diff_args, spec = tree_flatten(diff_args)

                # NB: need create_graph so that backward pass isn't run in no_grad mode
                flat_outputs = _as_tuple(output)
                flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
                grad_input = tree_unflatten(flat_grad_input, spec)

                grad_input = _undo_create_differentiable(grad_input, level)
                output = _undo_create_differentiable(output, level)
                if aux is not None:
                    aux = _undo_create_differentiable(aux, level)

            if has_aux:
                return grad_input, (output, aux)
            return grad_input, output
        finally:
            _grad_decrement_nesting()
    return wrapper

@doesnt_support_saved_tensors_hooks
def grad_and_value_impl(func, argnums, has_aux, args, kwargs) -> Callable:
    with grad_increment_nesting() as level:
        output, aux, grad_input = None, None, None
        # See NOTE [grad and vjp interaction with no_grad]
        with torch.enable_grad():
            args = _wrap_all_tensors(args, level)
            kwargs = _wrap_all_tensors(kwargs, level)
            diff_args = _slice_argnums(args, argnums, as_tuple=False)
            tree_map_(partial(_create_differentiable, level=level), diff_args)

            output = func(*args, **kwargs)
            if has_aux:
                if not (isinstance(output, tuple) and len(output) == 2):
                    raise RuntimeError(
                        "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
                        "if has_aux is True"
                    )
                output, aux = output

            if not isinstance(output, torch.Tensor):
                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                   f'to return a Tensor, got {type(output)}')
            if output.dim() != 0:
                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
                                   'to return a scalar Tensor, got tensor with '
                                   f'{output.dim()} dims. Maybe you wanted to '
                                   'use the vjp or jacrev APIs instead?')

            flat_diff_args, spec = tree_flatten(diff_args)

            # NB: need create_graph so that backward pass isn't run in no_grad mode
            flat_outputs = _as_tuple(output)
            flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
            grad_input = tree_unflatten(flat_grad_input, spec)

            grad_input = _undo_create_differentiable(grad_input, level)
            output = _undo_create_differentiable(output, level)
            if has_aux:
                aux = _undo_create_differentiable(aux, level)

        if has_aux:
            return grad_input, (output, aux)
        return grad_input, output


def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
    func = lazy_dynamo_disable(func)
    results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
    if has_aux:
        grad, (_, aux) = results
        return grad, aux

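grad_impl now delegates straight to grad_and_value_impl, dropping the grad_and_value wrapper and the lazy_dynamo_disable indirection, which is what lets dynamo inline these frames. A minimal sketch of the call pattern involved (illustration only, mirroring the tests above; the exact tracing behavior depends on the capture_func_transforms config those tests patch):

import torch

def f(t):
    return t.sin().sum()

@torch.compile(backend="eager")
def fn(x):
    return torch.func.grad(f)(x)

x = torch.randn(3)
# grad of sum(sin(x)) is cos(x), whether or not dynamo inlines the transform
assert torch.allclose(fn(x), x.cos())
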
@@ -14,6 +14,7 @@ from weakref import ReferenceType
import torch
import torch._custom_op
import torch._logging
from torch._C._functorch import is_functorch_wrapped_tensor

from torch._guards import Source
from torch._ops import OpOverload

@@ -124,7 +125,7 @@ def is_fake(x):
        reapply_views = torch._C._functionalization_reapply_views_tls()
        unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views)
        return is_fake(unwrapped)
    elif isinstance(x, torch.Tensor) and torch._C._functorch.is_batchedtensor(x):
    elif isinstance(x, torch.Tensor) and is_functorch_wrapped_tensor(x):
        unwrapped = torch._C._functorch.get_unwrapped(x)
        return is_fake(unwrapped)
    return False

@@ -145,7 +146,7 @@ def maybe_get_fake_mode(t):
        reapply_views = torch._C._functionalization_reapply_views_tls()
        unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views)
        return maybe_get_fake_mode(unwrapped)
    elif isinstance(t, torch.Tensor) and torch._C._functorch.is_batchedtensor(t):
    elif isinstance(t, torch.Tensor) and is_functorch_wrapped_tensor(t):
        unwrapped = torch._C._functorch.get_unwrapped(t)
        return maybe_get_fake_mode(unwrapped)
    return None

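is_fake and maybe_get_fake_mode now unwrap any functorch-wrapped tensor rather than only BatchedTensor. Inside torch.func.grad the argument is a GradTrackingTensor, which the old is_batchedtensor check would miss; a minimal sketch using only the private helpers imported in these hunks (illustration only):

import torch
from torch._C._functorch import is_batchedtensor, is_functorch_wrapped_tensor

def f(x):
    # x is grad-tracking (functorch-wrapped) but not batched here
    assert is_functorch_wrapped_tensor(x)
    assert not is_batchedtensor(x)
    return x.sin().sum()

torch.func.grad(f)(torch.randn(3))
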
@@ -11,6 +11,8 @@ from torch._C._functorch import (
    current_level,
    get_unwrapped,
    is_batchedtensor,
    is_functorch_wrapped_tensor,
    is_gradtrackingtensor,
    maybe_get_bdim,
    maybe_get_level,
    peek_interpreter_stack,

@@ -133,7 +135,7 @@ class MetaConverter:
        # hold a weak ref to self, otherwise it will be kept alive
        # by the del_ten closure
        self_weak_ref = weakref.ref(self)
        if t.is_sparse or t.is_mkldnn or is_batchedtensor(t):
        if t.is_sparse or t.is_mkldnn or is_functorch_wrapped_tensor(t):
            weak_st = None
        else:
            weak_st = StorageWeakRef(t._typed_storage())

@@ -327,16 +329,36 @@ class MetaConverter:
                if t.requires_grad and not is_leaf:
                    with torch.enable_grad():
                        r = r.clone()
            elif is_batchedtensor(t):
                # Wraps a BatchedTensor in a FakeTensor
            elif is_functorch_wrapped_tensor(t):
                if t._is_view():
                    from torch._dynamo.exc import unimplemented

                    unimplemented(
                        "view functorch tensors are not supported by meta conversion"
                    )

                # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
                # in a FakeTensor
                def _to_fake_tensor(t):
                    if is_batchedtensor(t):
                        ft = _to_fake_tensor(get_unwrapped(t))
                        lvl = maybe_get_level(t)
                        bdim = maybe_get_bdim(t)
                        r = _add_batch_dim(ft, bdim, lvl)
                    elif is_gradtrackingtensor(t):
                        disable_functorch = torch._C._DisableFuncTorch
                        with disable_functorch():
                            ft = _to_fake_tensor(get_unwrapped(t))
                        lvl = torch._C._functorch.maybe_get_level(t)
                        r = torch._C._functorch._wrap_for_grad(ft, lvl)

                        is_leaf = safe_is_leaf(t)
                        if t.requires_grad and safe_is_leaf(r):
                            r.requires_grad = True
                        elif t.requires_grad and not is_leaf:
                            with torch.enable_grad():
                                r = r.clone()
            else:
                # regular tensor
                sizes = t.size()
                strides = t.stride()
                r = callback(

@@ -537,6 +559,7 @@ class MetaConverter:
                        device="meta",
                    )
                )

            assert safe_is_leaf(r), "the callback you passed in doesn't detach"
            if t.requires_grad:
                r.requires_grad = t.requires_grad

@@ -549,8 +572,8 @@ class MetaConverter:
                r = r.clone(memory_format=torch.preserve_format)

        # Graph-Break for wrapped tensors
        if not is_batchedtensor(
            t
        if not (
            is_batchedtensor(t) or is_gradtrackingtensor(t)
        ) and torch._C._functorch.is_functorch_wrapped_tensor(t):
            return NotImplemented

@@ -701,13 +724,16 @@ class MetaConverter:
                return NotImplemented
        else:
            self.hit += 1
            r = self.meta_tensor(
                t,
                shape_env=shape_env,
                callback=callback,
                source=source,
                symbolic_context=symbolic_context,
            )
            disable_functorch = torch._C._DisableFuncTorch
            with disable_functorch():
                r = self.meta_tensor(
                    t,
                    shape_env=shape_env,
                    callback=callback,
                    source=source,
                    symbolic_context=symbolic_context,
                )
        if type(t) is torch.nn.Parameter:
            # NB: Cannot directly use Parameter constructor
            # because that would force a detach, not desirable

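The recursive _to_fake_tensor above peels functorch wrappers from the outside in and re-wraps the fake base tensor layer by layer. A minimal sketch of the same unwrapping walk, using only the helpers imported in this hunk (illustration only; the printed layers are an example, not output from this code):

import torch
from torch._C._functorch import (
    get_unwrapped,
    is_batchedtensor,
    is_gradtrackingtensor,
    maybe_get_level,
)

def wrapper_layers(t):
    # Outermost wrapper first, the same traversal order as _to_fake_tensor.
    layers = []
    while is_batchedtensor(t) or is_gradtrackingtensor(t):
        kind = "batched" if is_batchedtensor(t) else "grad"
        layers.append((kind, maybe_get_level(t)))
        t = get_unwrapped(t)
    return layers

def f(x):
    print(wrapper_layers(x))  # e.g. [('grad', 2), ('batched', 1)]
    return x.sin().sum()

torch.vmap(torch.func.grad(f))(torch.randn(4, 3))
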
@@ -1,5 +1,4 @@
from torch._functorch.eager_transforms import (
    grad_and_value,
    vjp,
    jvp,
    jacrev,

@@ -8,7 +7,7 @@ from torch._functorch.eager_transforms import (
    functionalize,
    linearize
)
from torch._functorch.apis import grad
from torch._functorch.apis import grad, grad_and_value
from torch._functorch.functional_call import functional_call, stack_module_state
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.apis import vmap

@@ -5018,6 +5018,7 @@ def munge_exc(e, *, suppress_suffix=True, suppress_prefix=True, file=None, skip=

    s = re.sub(r' File "([^"]+)", line \d+, in (.+)\n .+\n( +[~^]+ *\n)?', repl_frame, s)
    s = re.sub(r"line \d+", "line N", s)
    s = re.sub(r".py:\d+", ".py:N", s)
    s = re.sub(file, os.path.basename(file), s)
    s = re.sub(os.path.join(os.path.dirname(torch.__file__), ""), "", s)
    s = re.sub(r"\\", "/", s)  # for Windows

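The added substitution normalizes file:line references (".py:123" becomes ".py:N") so expected test strings run through munge_exc do not depend on source line numbers. A minimal sketch of just that regex on a made-up message (illustration only):

import re

msg = "triggered by a guard at torch/_functorch/pyfunctorch.py:142"
msg = re.sub(r"line \d+", "line N", msg)
msg = re.sub(r".py:\d+", ".py:N", msg)
print(msg)  # ... pyfunctorch.py:N
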