mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
cc @ezyang @gchanan Pull Request resolved: https://github.com/pytorch/pytorch/pull/87370 Approved by: https://github.com/soulitzer
197 lines
6.7 KiB
Python
197 lines
6.7 KiB
Python
# Owner(s): ["module: nn"]
|
|
from torch.testing._internal.common_utils import (
|
|
TestCase,
|
|
run_tests,
|
|
skipIfTorchDynamo,
|
|
)
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from functools import partial
|
|
from typing import List, Tuple
|
|
|
|
|
|
class Net(nn.Module):
    """Two ``nn.Sequential`` stages (``seq1`` then ``seq2``), each of two ``Linear(10, 10)``."""

    def __init__(self) -> None:
        super().__init__()
        # Layers are created left-to-right, stage by stage, so parameter
        # registration order and RNG consumption match a comprehension build.
        self.seq1 = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))
        self.seq2 = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``seq1`` to ``x`` and feed the result through ``seq2``."""
        hidden = self.seq1(x)
        return self.seq2(hidden)
|
|
|
|
|
|
class ToyModel(nn.Module):
    """Two ``Net`` instances chained so that ``net1`` feeds ``net2``."""

    def __init__(self) -> None:
        super().__init__()
        # net1 is constructed before net2 (keeps parameter/RNG order stable).
        self.net1 = Net()
        self.net2 = Net()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through ``net1`` and hand the result to ``net2``."""
        intermediate = self.net1(x)
        return self.net2(intermediate)
|
|
|
|
|
|
def forward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
    out: torch.Tensor,
) -> None:
    """Record ``hook_id`` and validate the arguments of a forward hook.

    Meant to be bound with ``functools.partial`` over its first four
    arguments and registered via ``register_forward_hook``; the remaining
    ``module``, ``inp`` and ``out`` are supplied when the hook fires.

    Args:
        self: Test case used to run the assertions.
        fired_hooks: Log that ``hook_id`` is appended to on every call.
        expected_module: The module this hook must fire for.
        hook_id: Identifier recorded in ``fired_hooks``.
        module: The module the hook actually fired for.
        inp: Positional inputs of the forward call; exactly one is expected.
        out: Forward output (not inspected).

    Raises:
        AssertionError: If ``module`` is not ``expected_module`` or ``inp``
            does not hold exactly one element.
    """
    # Record before asserting so the firing order is kept even on failure.
    fired_hooks.append(hook_id)
    # assertIs reports the offending objects; comparing id()s only shows ints.
    self.assertIs(module, expected_module)
    self.assertEqual(len(inp), 1)
|
|
|
|
|
|
def forward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    inp: Tuple[torch.Tensor],
) -> None:
    """Record ``hook_id`` and validate the arguments of a forward pre-hook.

    Meant to be bound with ``functools.partial`` over its first four
    arguments and registered via ``register_forward_pre_hook``; ``module``
    and ``inp`` are supplied when the hook fires. Returns ``None`` so the
    forward inputs are left unchanged.

    Args:
        self: Test case used to run the assertions.
        fired_hooks: Log that ``hook_id`` is appended to on every call.
        expected_module: The module this hook must fire for.
        hook_id: Identifier recorded in ``fired_hooks``.
        module: The module the hook actually fired for.
        inp: Positional inputs of the forward call; exactly one is expected.

    Raises:
        AssertionError: If ``module`` is not ``expected_module`` or ``inp``
            does not hold exactly one element.
    """
    # Record before asserting so the firing order is kept even on failure.
    fired_hooks.append(hook_id)
    # assertIs reports the offending objects; comparing id()s only shows ints.
    self.assertIs(module, expected_module)
    self.assertEqual(len(inp), 1)
|
|
|
|
|
|
def full_backward_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
    grad_output: Tuple[torch.Tensor],
) -> None:
    """Record ``hook_id`` and validate the arguments of a full backward hook.

    Meant to be bound with ``functools.partial`` over its first four
    arguments and registered via ``register_full_backward_hook``;
    ``module``, ``grad_input`` and ``grad_output`` are supplied when the
    hook fires during backward.

    Args:
        self: Test case used to run the assertions.
        fired_hooks: Log that ``hook_id`` is appended to on every call.
        expected_module: The module this hook must fire for.
        hook_id: Identifier recorded in ``fired_hooks``.
        module: The module the hook actually fired for.
        grad_input: Gradient tuple for the module inputs; one entry expected
            (an entry may be ``None`` if that input does not require grad).
        grad_output: Gradient tuple for the module outputs; one entry expected.

    Raises:
        AssertionError: If ``module`` is not ``expected_module`` or either
            gradient tuple does not hold exactly one element.
    """
    # Record before asserting so the firing order is kept even on failure.
    fired_hooks.append(hook_id)
    # assertIs reports the offending objects; comparing id()s only shows ints.
    self.assertIs(module, expected_module)
    self.assertEqual(len(grad_input), 1)
    self.assertEqual(len(grad_output), 1)
|
|
|
|
|
|
def full_backward_pre_hook(
    self: TestCase,
    fired_hooks: List[int],
    expected_module: nn.Module,
    hook_id: int,
    module: nn.Module,
    grad_input: Tuple[torch.Tensor],
) -> None:
    """Record ``hook_id`` and validate the arguments of a full backward pre-hook.

    Meant to be bound with ``functools.partial`` over its first four
    arguments and registered via ``register_full_backward_pre_hook``;
    ``module`` and the gradient tuple are supplied when the hook fires.

    Args:
        self: Test case used to run the assertions.
        fired_hooks: Log that ``hook_id`` is appended to on every call.
        expected_module: The module this hook must fire for.
        hook_id: Identifier recorded in ``fired_hooks``.
        module: The module the hook actually fired for.
        grad_input: Despite its name, this receives what PyTorch passes to a
            backward pre-hook, i.e. the gradients w.r.t. the module's
            *outputs*. The name is kept for backward compatibility. One
            entry is expected.

    Raises:
        AssertionError: If ``module`` is not ``expected_module`` or the
            gradient tuple does not hold exactly one element.
    """
    # Record before asserting so the firing order is kept even on failure.
    fired_hooks.append(hook_id)
    # assertIs reports the offending objects; comparing id()s only shows ints.
    self.assertIs(module, expected_module)
    self.assertEqual(len(grad_input), 1)
|
|
|
|
|
|
class TestModuleHooks(TestCase):
    """Verify the firing order of ``nn.Module`` hooks, including ``prepend=True``.

    Hooks are built from the module-level helpers, which append their
    ``hook_id`` to a shared ``fired_hooks`` list when they run.
    """

    @staticmethod
    def _register_in_order(register_fn, hook, specs):
        """Call ``register_fn(partial(hook, hook_id), prepend=prepend)`` for
        each ``(hook_id, prepend)`` pair of ``specs``, in sequence."""
        for hook_id, prepend in specs:
            register_fn(partial(hook, hook_id), prepend=prepend)

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_forward_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        target = model.net1.seq2
        self._register_in_order(
            target.register_forward_hook,
            partial(forward_hook, self, fired_hooks, target),
            ((0, False), (1, True), (2, False), (3, False), (4, True)),
        )
        # Prepended hooks run newest-first, ahead of the appended ones.
        expected = [4, 1, 0, 2, 3]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, expected)
        # Backward must not re-trigger forward hooks.
        out.sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x).sum().backward()
        self.assertEqual(fired_hooks, expected * 2)

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_forward_pre_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        target = model.net2.seq1
        self._register_in_order(
            target.register_forward_pre_hook,
            partial(forward_pre_hook, self, fired_hooks, target),
            ((0, True), (1, False), (2, False), (3, False), (4, True)),
        )
        expected = [4, 0, 1, 2, 3]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, expected)
        # Backward must not re-trigger forward pre-hooks.
        out.sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x).sum().backward()
        self.assertEqual(fired_hooks, expected * 2)

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_full_backward_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        target = model.net1
        self._register_in_order(
            target.register_full_backward_hook,
            partial(full_backward_hook, self, fired_hooks, target),
            ((0, False), (1, False), (2, False), (3, True), (4, True)),
        )
        expected = [4, 3, 0, 1, 2]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        # Backward hooks stay silent during the forward pass.
        self.assertEqual(fired_hooks, [])
        out.sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x).sum().backward()
        self.assertEqual(fired_hooks, expected * 2)

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_full_backward_pre_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        target = model.net1
        self._register_in_order(
            target.register_full_backward_pre_hook,
            partial(full_backward_pre_hook, self, fired_hooks, target),
            ((0, True), (1, True), (2, False), (3, False), (4, False)),
        )
        expected = [1, 0, 2, 3, 4]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        # Backward pre-hooks stay silent during the forward pass.
        self.assertEqual(fired_hooks, [])
        out.sum().backward()
        self.assertEqual(fired_hooks, expected)
        model(x).sum().backward()
        self.assertEqual(fired_hooks, expected * 2)

    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
    def test_mixed_hooks(self):
        fired_hooks: List[int] = []
        model = ToyModel()
        x = torch.randn(10, 10)
        # The hook id equals the position in this table (registration order).
        registrations = (
            (model.register_forward_pre_hook, forward_pre_hook),
            (model.register_forward_hook, forward_hook),
            (model.register_full_backward_pre_hook, full_backward_pre_hook),
            (model.register_full_backward_hook, full_backward_hook),
        )
        for hook_id, (register_fn, helper) in enumerate(registrations):
            register_fn(partial(helper, self, fired_hooks, model, hook_id))
        one_round = [0, 1, 2, 3]

        self.assertEqual(fired_hooks, [])
        out = model(x)
        self.assertEqual(fired_hooks, one_round[:2])
        out.sum().backward()
        self.assertEqual(fired_hooks, one_round)
        model(x).sum().backward()
        self.assertEqual(fired_hooks, one_round * 2)
|
|
|
|
|
|
def _main() -> None:
    """Delegate to ``run_tests`` from ``torch.testing._internal.common_utils``."""
    run_tests()


if __name__ == "__main__":
    # Only when the file is executed directly, never on import.
    _main()
|