Fix composable checkpoint(use_reentrant=True) with multi args (#103590)

`_ModuleHookCheckpointFunction.backward()` should take `*output_grads` instead of `output_grads`. The autograd engine passes one positional gradient per forward output (in addition to `ctx`), so a checkpointed module with multiple outputs overflows the fixed two-parameter signature with an error like:
```
TypeError: backward() takes 2 positional arguments but 5 were given
```
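
For background, here is a minimal sketch (not the PyTorch source) of the convention at play: `torch.autograd.Function.backward` receives one positional gradient per tensor returned by `forward`, so a multi-output function must declare a variadic signature:
```python
import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin(), x.cos()  # two outputs -> backward() will receive two grads

    @staticmethod
    def backward(ctx, *output_grads):  # variadic; `def backward(ctx, output_grads)` would raise TypeError here
        (x,) = ctx.saved_tensors
        g_sin, g_cos = output_grads
        return g_sin * x.cos() - g_cos * x.sin()  # d(sin)/dx = cos, d(cos)/dx = -sin

x = torch.randn(3, requires_grad=True)
a, b = TwoOutputs.apply(x)
(a.sum() + b.sum()).backward()  # engine calls backward(ctx, grad_a, grad_b)
```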
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
Author: Andrew Gu, 2023-06-14 13:20:10 +00:00; committed by PyTorch MergeBot
Commit: 2eea3cb19d (parent: c2952e8be9)
2 changed files with 51 additions and 1 deletion

test/distributed/_composable/test_checkpoint.py

```diff
@@ -4,6 +4,7 @@ import unittest
 from collections import deque
 from contextlib import ContextDecorator
 from copy import deepcopy
+from typing import Tuple
 
 import torch
 import torch.nn as nn
```
```diff
@@ -66,6 +67,32 @@ class RandomModel(nn.Module):
         return torch.matmul(x, y)
 
 
+class MultiOutputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
+        self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        z = x @ self.w1
+        z = nn.functional.relu(z)
+        z = z @ self.w2
+        return z.sin(), z.cos()
+
+
+class MultiInputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
+        x, y = xs
+        z = x + y
+        z = z @ self.w
+        return nn.functional.relu(z)
+
+
 class TestCheckpoint(TestCase):
     def _get_graph_size(self, out: torch.Tensor) -> int:
         q = deque([out.grad_fn])
```
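
A quick illustration (assuming the two classes above) of how these models chain: `MultiOutputModel` emits a tuple that `MultiInputModel` consumes, which is exactly what produces multiple gradient arguments at the checkpoint boundary during backward:
```python
device = torch.device("cpu")
pair = MultiOutputModel(device)(torch.randn(20, 100))  # (z.sin(), z.cos())
out = MultiInputModel(device)(pair)                    # tuple in, single tensor out
print(out.shape)  # torch.Size([20, 100])
```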
```diff
@@ -143,6 +170,29 @@ class TestCheckpoint(TestCase):
         for p1, p2 in zip(net1.parameters(), net2.parameters()):
             self.assertEqual(p1.grad, p2.grad)
 
+    @parametrize("use_reentrant", [True, False])
+    def test_multi_args(self, use_reentrant: bool):
+        """
+        Tests checkpoint for modules with multiple output args and hence
+        multiple backward function input args.
+        """
+        device = torch.device("cpu")
+        net1 = nn.Sequential(
+            MultiOutputModel(device),
+            MultiInputModel(device),
+            MultiOutputModel(device),
+            MultiInputModel(device),
+        )
+        net2 = deepcopy(net1)
+        checkpoint(net2[0], use_reentrant=use_reentrant)
+        checkpoint(net2[2], use_reentrant=use_reentrant)
+        x1 = torch.randn(20, 100, requires_grad=True)
+        x2 = x1.clone()
+        net1(x1).sum().backward()
+        net2(x2).sum().backward()
+        for p1, p2 in zip(net1.parameters(), net2.parameters()):
+            self.assertEqual(p1.grad, p2.grad)
+
 
 instantiate_parametrized_tests(TestCheckpoint)
```
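
Outside the test harness, the pattern under test looks roughly like the following sketch. It assumes the composable `checkpoint` from `torch.distributed._composable`, which (unlike `torch.utils.checkpoint.checkpoint`) annotates the module in place rather than returning a wrapper:
```python
import torch
import torch.nn as nn
from copy import deepcopy
from torch.distributed._composable import checkpoint  # composable API (assumed import path)

net1 = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
net2 = deepcopy(net1)
checkpoint(net2[0], use_reentrant=True)  # in place: net2[0] is recomputed during backward

x = torch.randn(4, 8)
net1(x).sum().backward()
net2(x).sum().backward()
for p1, p2 in zip(net1.parameters(), net2.parameters()):
    assert torch.allclose(p1.grad, p2.grad)  # checkpointing must not change grads
```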

torch/distributed/_composable/checkpoint_activation.py

```diff
@@ -47,7 +47,7 @@ class _ModuleHookCheckpointFunction(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
+    def backward(ctx, *output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError(
                 "Checkpointing is not compatible with .grad() or when an "
```