Fix composable checkpoint(use_reentrant=True) with multi args (#103590)
The `_ModuleHookCheckpointFunction.backward()` should take in `*output_grads` instead of `output_grads`. Otherwise, we may see an error like:

```
TypeError: backward() takes 2 positional arguments but 5 were given
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
This commit is contained in: parent c2952e8be9, commit 2eea3cb19d
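For background (not part of the commit): autograd invokes a custom `torch.autograd.Function`'s `backward()` with one gradient argument per forward output, so a checkpoint wrapper around a module that returns several tensors must accept those gradients variadically. Below is a minimal sketch of that calling convention; the `SinCos` Function is a made-up toy, not the PR's code.

```python
import torch


class SinCos(torch.autograd.Function):
    """Toy Function with two forward outputs (illustration only)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin(), x.cos()  # two outputs -> backward receives two grads

    @staticmethod
    def backward(ctx, *grad_outputs):  # variadic: len(grad_outputs) == 2 here
        (x,) = ctx.saved_tensors
        g_sin, g_cos = grad_outputs
        # Chain rule, combining both output grads into the single input grad.
        return g_sin * x.cos() - g_cos * x.sin()


x = torch.randn(4, requires_grad=True)
a, b = SinCos.apply(x)
(a.sum() + b.sum()).backward()
# Had backward been declared as `backward(ctx, grad_outputs)`, this call would
# fail with a TypeError of the same form as the one quoted above
# ("backward() takes 2 positional arguments but 3 were given" in this toy case).
```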
@@ -4,6 +4,7 @@ import unittest
 from collections import deque
 from contextlib import ContextDecorator
 from copy import deepcopy
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -66,6 +67,32 @@ class RandomModel(nn.Module):
         return torch.matmul(x, y)
 
 
+class MultiOutputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
+        self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        z = x @ self.w1
+        z = nn.functional.relu(z)
+        z = z @ self.w2
+        return z.sin(), z.cos()
+
+
+class MultiInputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
+        x, y = xs
+        z = x + y
+        z = z @ self.w
+        return nn.functional.relu(z)
+
+
 class TestCheckpoint(TestCase):
     def _get_graph_size(self, out: torch.Tensor) -> int:
         q = deque([out.grad_fn])
@@ -143,6 +170,29 @@ class TestCheckpoint(TestCase):
         for p1, p2 in zip(net1.parameters(), net2.parameters()):
             self.assertEqual(p1.grad, p2.grad)
 
+    @parametrize("use_reentrant", [True, False])
+    def test_multi_args(self, use_reentrant: bool):
+        """
+        Tests checkpoint for modules with multiple output args and hence
+        multiple backward function input args.
+        """
+        device = torch.device("cpu")
+        net1 = nn.Sequential(
+            MultiOutputModel(device),
+            MultiInputModel(device),
+            MultiOutputModel(device),
+            MultiInputModel(device),
+        )
+        net2 = deepcopy(net1)
+        checkpoint(net2[0], use_reentrant=use_reentrant)
+        checkpoint(net2[2], use_reentrant=use_reentrant)
+        x1 = torch.randn(20, 100, requires_grad=True)
+        x2 = x1.clone()
+        net1(x1).sum().backward()
+        net2(x2).sum().backward()
+        for p1, p2 in zip(net1.parameters(), net2.parameters()):
+            self.assertEqual(p1.grad, p2.grad)
+
 
 instantiate_parametrized_tests(TestCheckpoint)
 
@@ -47,7 +47,7 @@ class _ModuleHookCheckpointFunction(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
+    def backward(ctx, *output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError(
                 "Checkpointing is not compatible with .grad() or when an "
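To see the fixed path end to end, here is a hedged repro sketch mirroring the new test. It assumes the composable API import `from torch.distributed._composable import checkpoint` as of this commit, and the `TwoOutputs` module is an illustrative stand-in for `MultiOutputModel`, not code from the PR.

```python
import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint  # composable checkpoint API


class TwoOutputs(nn.Module):
    """Illustrative module returning two tensors, like MultiOutputModel."""

    def __init__(self) -> None:
        super().__init__()
        self.w = nn.Parameter(torch.randn(100, 100))

    def forward(self, x: torch.Tensor):
        z = x @ self.w
        return z.sin(), z.cos()  # multiple outputs -> multiple backward grads


net = TwoOutputs()
checkpoint(net, use_reentrant=True)  # applied in place, as in the test above
out1, out2 = net(torch.randn(20, 100, requires_grad=True))
# Per the commit message, the reentrant path previously raised the quoted
# TypeError during this backward pass; with `*output_grads` it completes and
# the gradients match an un-checkpointed run.
(out1.sum() + out2.sum()).backward()
```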