From 2eea3cb19d48d80fa81f83ce04da6ac35a5caedb Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Wed, 14 Jun 2023 13:20:10 +0000
Subject: [PATCH] Fix composable `checkpoint(use_reentrant=True)` with multi
 args (#103590)

The `_ModuleHookCheckpointFunction.backward()` should take in
`*output_grads` instead of `output_grads`. Otherwise, we may see an error
like:
```
TypeError: backward() takes 2 positional arguments but 5 were given
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590
Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
---
 .../_composable/test_checkpoint.py        | 50 +++++++++++++++++++
 .../_composable/checkpoint_activation.py  |  2 +-
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/test/distributed/_composable/test_checkpoint.py b/test/distributed/_composable/test_checkpoint.py
index b5f2267fd5b..0a49ac4ae50 100644
--- a/test/distributed/_composable/test_checkpoint.py
+++ b/test/distributed/_composable/test_checkpoint.py
@@ -4,6 +4,7 @@ import unittest
 from collections import deque
 from contextlib import ContextDecorator
 from copy import deepcopy
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -66,6 +67,32 @@ class RandomModel(nn.Module):
         return torch.matmul(x, y)
 
 
+class MultiOutputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
+        self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        z = x @ self.w1
+        z = nn.functional.relu(z)
+        z = z @ self.w2
+        return z.sin(), z.cos()
+
+
+class MultiInputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
+        x, y = xs
+        z = x + y
+        z = z @ self.w
+        return nn.functional.relu(z)
+
+
 class TestCheckpoint(TestCase):
     def _get_graph_size(self, out: torch.Tensor) -> int:
         q = deque([out.grad_fn])
@@ -143,6 +170,29 @@ class TestCheckpoint(TestCase):
         for p1, p2 in zip(net1.parameters(), net2.parameters()):
             self.assertEqual(p1.grad, p2.grad)
 
+    @parametrize("use_reentrant", [True, False])
+    def test_multi_args(self, use_reentrant: bool):
+        """
+        Tests checkpoint for modules with multiple output args and hence
+        multiple backward function input args.
+        """
+        device = torch.device("cpu")
+        net1 = nn.Sequential(
+            MultiOutputModel(device),
+            MultiInputModel(device),
+            MultiOutputModel(device),
+            MultiInputModel(device),
+        )
+        net2 = deepcopy(net1)
+        checkpoint(net2[0], use_reentrant=use_reentrant)
+        checkpoint(net2[2], use_reentrant=use_reentrant)
+        x1 = torch.randn(20, 100, requires_grad=True)
+        x2 = x1.clone()
+        net1(x1).sum().backward()
+        net2(x2).sum().backward()
+        for p1, p2 in zip(net1.parameters(), net2.parameters()):
+            self.assertEqual(p1.grad, p2.grad)
+
 
 instantiate_parametrized_tests(TestCheckpoint)
 
diff --git a/torch/distributed/_composable/checkpoint_activation.py b/torch/distributed/_composable/checkpoint_activation.py
index 0bbc93d9345..2533fe37f93 100644
--- a/torch/distributed/_composable/checkpoint_activation.py
+++ b/torch/distributed/_composable/checkpoint_activation.py
@@ -47,7 +47,7 @@ class _ModuleHookCheckpointFunction(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
+    def backward(ctx, *output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError(
                 "Checkpointing is not compatible with .grad() or when an "
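Illustrative note (not part of the patch above): `torch.autograd.Function` passes one gradient argument to `backward()` per tensor returned by `forward()`, which is why the hook's `backward` must declare a variadic `*output_grads`. The minimal sketch below uses a hypothetical `TwoOutputs` function, unrelated to the PyTorch sources, to show the same contract with two outputs; declaring a single positional `output_grads` parameter there would fail with the analogous `TypeError: backward() takes 2 positional arguments but 3 were given`.
```
# Minimal sketch (assumption: standalone illustration, not code from the patch).
# Autograd supplies one gradient per forward output, so backward() must accept
# a variadic *output_grads when the function returns multiple tensors.
import torch


class TwoOutputs(torch.autograd.Function):  # hypothetical example Function
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin(), x.cos()  # two outputs -> two gradients in backward

    @staticmethod
    def backward(ctx, *output_grads):
        # output_grads == (grad_for_sin, grad_for_cos); a signature of
        # backward(ctx, output_grads) would instead raise the TypeError above.
        (x,) = ctx.saved_tensors
        grad_sin, grad_cos = output_grads
        # d/dx sin(x) = cos(x), d/dx cos(x) = -sin(x)
        return grad_sin * x.cos() - grad_cos * x.sin()


x = torch.randn(4, requires_grad=True)
a, b = TwoOutputs.apply(x)
(a + b).sum().backward()  # autograd hands backward() two gradient tensors
```
Running the snippet exercises `backward()` with two gradient tensors, mirroring at smaller scale the multi-output case that produced the error quoted in the commit message.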