mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix composable checkpoint(use_reentrant=True) with multi args (#103590)
The `_ModuleHookCheckpointFunction.backward()` should take in `*output_grads` instead of `output_grads`. Otherwise, we may see an error like: ``` TypeError: backward() takes 2 positional arguments but 5 were given ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590 Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
This commit is contained in:
parent
c2952e8be9
commit
2eea3cb19d
|
|
@ -4,6 +4,7 @@ import unittest
|
|||
from collections import deque
|
||||
from contextlib import ContextDecorator
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -66,6 +67,32 @@ class RandomModel(nn.Module):
|
|||
return torch.matmul(x, y)
|
||||
|
||||
|
||||
class MultiOutputModel(nn.Module):
|
||||
def __init__(self, device: torch.device):
|
||||
super().__init__()
|
||||
self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
|
||||
self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
z = x @ self.w1
|
||||
z = nn.functional.relu(z)
|
||||
z = z @ self.w2
|
||||
return z.sin(), z.cos()
|
||||
|
||||
|
||||
class MultiInputModel(nn.Module):
|
||||
def __init__(self, device: torch.device):
|
||||
super().__init__()
|
||||
self.w = nn.Parameter(torch.randn((100, 100), device=device))
|
||||
|
||||
def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
||||
assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
|
||||
x, y = xs
|
||||
z = x + y
|
||||
z = z @ self.w
|
||||
return nn.functional.relu(z)
|
||||
|
||||
|
||||
class TestCheckpoint(TestCase):
|
||||
def _get_graph_size(self, out: torch.Tensor) -> int:
|
||||
q = deque([out.grad_fn])
|
||||
|
|
@ -143,6 +170,29 @@ class TestCheckpoint(TestCase):
|
|||
for p1, p2 in zip(net1.parameters(), net2.parameters()):
|
||||
self.assertEqual(p1.grad, p2.grad)
|
||||
|
||||
@parametrize("use_reentrant", [True, False])
|
||||
def test_multi_args(self, use_reentrant: bool):
|
||||
"""
|
||||
Tests checkpoint for modules with multiple output args and hence
|
||||
multiple backward function input args.
|
||||
"""
|
||||
device = torch.device("cpu")
|
||||
net1 = nn.Sequential(
|
||||
MultiOutputModel(device),
|
||||
MultiInputModel(device),
|
||||
MultiOutputModel(device),
|
||||
MultiInputModel(device),
|
||||
)
|
||||
net2 = deepcopy(net1)
|
||||
checkpoint(net2[0], use_reentrant=use_reentrant)
|
||||
checkpoint(net2[2], use_reentrant=use_reentrant)
|
||||
x1 = torch.randn(20, 100, requires_grad=True)
|
||||
x2 = x1.clone()
|
||||
net1(x1).sum().backward()
|
||||
net2(x2).sum().backward()
|
||||
for p1, p2 in zip(net1.parameters(), net2.parameters()):
|
||||
self.assertEqual(p1.grad, p2.grad)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestCheckpoint)
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class _ModuleHookCheckpointFunction(torch.autograd.Function):
|
|||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any: # type: ignore[override]
|
||||
def backward(ctx, *output_grads: Tuple[Optional[torch.Tensor]]) -> Any: # type: ignore[override]
|
||||
if not torch.autograd._is_checkpoint_valid():
|
||||
raise RuntimeError(
|
||||
"Checkpointing is not compatible with .grad() or when an "
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user