Fix composable checkpoint(use_reentrant=True) with multi args (#103590)

`_ModuleHookCheckpointFunction.backward()` should take `*output_grads` instead of a single `output_grads` argument, since autograd supplies one gradient per tensor returned from `forward()`. Otherwise, we may see an error like:
```
TypeError: backward() takes 2 positional arguments but 5 were given
```
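For context, `torch.autograd.Function.backward()` is called with one gradient argument per tensor returned from `forward()`, so the signature must be variadic once the checkpointed module can return multiple outputs. Below is a minimal, hypothetical sketch of that calling convention (the `TwoOutputFunction` name and gradient math are illustrative only, not code from this PR):
```python
import torch


class TwoOutputFunction(torch.autograd.Function):
    """Toy Function whose forward returns two tensors (illustration only)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin(), x.cos()

    @staticmethod
    def backward(ctx, *output_grads):
        # Autograd supplies one gradient per forward output, so `*output_grads`
        # receives (grad_sin, grad_cos). Declaring a single positional
        # `output_grads` parameter here would fail with a TypeError like the
        # one quoted above.
        (x,) = ctx.saved_tensors
        grad_sin, grad_cos = output_grads
        # d/dx sin(x) = cos(x), d/dx cos(x) = -sin(x)
        return grad_sin * x.cos() - grad_cos * x.sin()


x = torch.randn(4, requires_grad=True)
a, b = TwoOutputFunction.apply(x)
(a.sum() + b.sum()).backward()  # invokes backward() with two output grads
```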
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
Authored by Andrew Gu on 2023-06-14 13:20:10 +00:00; committed by PyTorch MergeBot
parent c2952e8be9
commit 2eea3cb19d
2 changed files with 51 additions and 1 deletion

View File

@@ -4,6 +4,7 @@ import unittest
 from collections import deque
 from contextlib import ContextDecorator
 from copy import deepcopy
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -66,6 +67,32 @@ class RandomModel(nn.Module):
         return torch.matmul(x, y)
 
 
+class MultiOutputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
+        self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        z = x @ self.w1
+        z = nn.functional.relu(z)
+        z = z @ self.w2
+        return z.sin(), z.cos()
+
+
+class MultiInputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
+        x, y = xs
+        z = x + y
+        z = z @ self.w
+        return nn.functional.relu(z)
+
+
 class TestCheckpoint(TestCase):
     def _get_graph_size(self, out: torch.Tensor) -> int:
         q = deque([out.grad_fn])
@@ -143,6 +170,29 @@ class TestCheckpoint(TestCase):
         for p1, p2 in zip(net1.parameters(), net2.parameters()):
             self.assertEqual(p1.grad, p2.grad)
 
+    @parametrize("use_reentrant", [True, False])
+    def test_multi_args(self, use_reentrant: bool):
+        """
+        Tests checkpoint for modules with multiple output args and hence
+        multiple backward function input args.
+        """
+        device = torch.device("cpu")
+        net1 = nn.Sequential(
+            MultiOutputModel(device),
+            MultiInputModel(device),
+            MultiOutputModel(device),
+            MultiInputModel(device),
+        )
+        net2 = deepcopy(net1)
+        checkpoint(net2[0], use_reentrant=use_reentrant)
+        checkpoint(net2[2], use_reentrant=use_reentrant)
+        x1 = torch.randn(20, 100, requires_grad=True)
+        x2 = x1.clone()
+        net1(x1).sum().backward()
+        net2(x2).sum().backward()
+        for p1, p2 in zip(net1.parameters(), net2.parameters()):
+            self.assertEqual(p1.grad, p2.grad)
+
 
 instantiate_parametrized_tests(TestCheckpoint)

View File

@@ -47,7 +47,7 @@ class _ModuleHookCheckpointFunction(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
+    def backward(ctx, *output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError(
                 "Checkpointing is not compatible with .grad() or when an "