diff --git a/test/distributed/_composable/test_checkpoint.py b/test/distributed/_composable/test_checkpoint.py
index b5f2267fd5b..0a49ac4ae50 100644
--- a/test/distributed/_composable/test_checkpoint.py
+++ b/test/distributed/_composable/test_checkpoint.py
@@ -4,6 +4,7 @@ import unittest
 from collections import deque
 from contextlib import ContextDecorator
 from copy import deepcopy
+from typing import Tuple
 
 import torch
 import torch.nn as nn
@@ -66,6 +67,32 @@ class RandomModel(nn.Module):
         return torch.matmul(x, y)
 
 
+class MultiOutputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w1 = nn.Parameter(torch.randn((100, 100), device=device))
+        self.w2 = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        z = x @ self.w1
+        z = nn.functional.relu(z)
+        z = z @ self.w2
+        return z.sin(), z.cos()
+
+
+class MultiInputModel(nn.Module):
+    def __init__(self, device: torch.device):
+        super().__init__()
+        self.w = nn.Parameter(torch.randn((100, 100), device=device))
+
+    def forward(self, xs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        assert len(xs) == 2, f"Expects 2 args but got {len(xs)}"
+        x, y = xs
+        z = x + y
+        z = z @ self.w
+        return nn.functional.relu(z)
+
+
 class TestCheckpoint(TestCase):
     def _get_graph_size(self, out: torch.Tensor) -> int:
         q = deque([out.grad_fn])
@@ -143,6 +170,29 @@ class TestCheckpoint(TestCase):
         for p1, p2 in zip(net1.parameters(), net2.parameters()):
             self.assertEqual(p1.grad, p2.grad)
 
+    @parametrize("use_reentrant", [True, False])
+    def test_multi_args(self, use_reentrant: bool):
+        """
+        Tests checkpoint for modules with multiple output args and hence
+        multiple backward function input args.
+        """
+        device = torch.device("cpu")
+        net1 = nn.Sequential(
+            MultiOutputModel(device),
+            MultiInputModel(device),
+            MultiOutputModel(device),
+            MultiInputModel(device),
+        )
+        net2 = deepcopy(net1)
+        checkpoint(net2[0], use_reentrant=use_reentrant)
+        checkpoint(net2[2], use_reentrant=use_reentrant)
+        x1 = torch.randn(20, 100, requires_grad=True)
+        x2 = x1.clone()
+        net1(x1).sum().backward()
+        net2(x2).sum().backward()
+        for p1, p2 in zip(net1.parameters(), net2.parameters()):
+            self.assertEqual(p1.grad, p2.grad)
+
 
 instantiate_parametrized_tests(TestCheckpoint)
 
diff --git a/torch/distributed/_composable/checkpoint_activation.py b/torch/distributed/_composable/checkpoint_activation.py
index 0bbc93d9345..2533fe37f93 100644
--- a/torch/distributed/_composable/checkpoint_activation.py
+++ b/torch/distributed/_composable/checkpoint_activation.py
@@ -47,7 +47,7 @@ class _ModuleHookCheckpointFunction(torch.autograd.Function):
         return output
 
     @staticmethod
-    def backward(ctx, output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
+    def backward(ctx, *output_grads: Tuple[Optional[torch.Tensor]]) -> Any:  # type: ignore[override]
         if not torch.autograd._is_checkpoint_valid():
             raise RuntimeError(
                 "Checkpointing is not compatible with .grad() or when an "
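
Why the variadic `*output_grads` signature matters (a minimal sketch, independent of the diff, relying only on stock `torch.autograd.Function` semantics): autograd passes one gradient argument per tensor returned from `forward`, so a checkpointed module with multiple outputs, such as `MultiOutputModel` above, invokes `backward` with several gradients. The `TwoOutputs` class below is a hypothetical toy, not code from either file in this PR.

import torch


class TwoOutputs(torch.autograd.Function):
    # Toy Function with two forward outputs; autograd therefore calls
    # backward with two gradient arguments, one per output.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin(), x.cos()

    @staticmethod
    def backward(ctx, *output_grads):  # (grad_wrt_sin, grad_wrt_cos)
        (x,) = ctx.saved_tensors
        grad_sin, grad_cos = output_grads
        # d/dx sin(x) = cos(x) and d/dx cos(x) = -sin(x); one gradient is
        # returned because forward took a single tensor input.
        return grad_sin * x.cos() - grad_cos * x.sin()


x = torch.randn(4, requires_grad=True)
a, b = TwoOutputs.apply(x)
(a.sum() + b.sum()).backward()  # two incoming gradients reach backward

With the non-variadic signature that the diff removes, Python would raise a TypeError for the extra positional argument as soon as the wrapped module returns more than one tensor, which is the case the new `test_multi_args` test exercises.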