[effects] Add way to register effectful op (#122348)

This adds a way to register an operator as effectful. I also added a test case that mimics our solution for intermediate logging ([doc](https://docs.google.com/document/d/1eLyGDVe4iplVFiO0I021uLgA4Y6HxK9eqn55e9KzQkc/edit#heading=h.uwec2ukkwhea)): it creates a custom op and registers it as effectful.
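
For reference, a minimal sketch of the registration flow, mirroring the test added in this PR (the `mylib::record_scalar_tensor` op is the one the test defines):

    import torch
    from torch._higher_order_ops.effects import _EffectType, _register_effectful_op

    # Define a custom op whose only job is a side effect (recording a scalar).
    lib = torch.library.Library("mylib", "FRAGMENT")
    torch.library.define(
        "mylib::record_scalar_tensor", "(Tensor x, str prefix) -> ()", lib=lib
    )

    # Register the op as effectful so tracing keeps its calls ordered
    # instead of treating them as dead code.
    _register_effectful_op(
        torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
    )

The test additionally gives the op a CompositeExplicitAutograd implementation and a meta function so it can run and be traced under torch.compile.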

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122348
Approved by: https://github.com/zou3519
ghstack dependencies: #122347
angelayi 2024-04-08 14:40:04 -07:00 committed by PyTorch MergeBot
parent 493478db4a
commit 7420b8c5be
2 changed files with 148 additions and 0 deletions

View File

@@ -1,5 +1,8 @@
# Owner(s): ["module: functorch"]
import unittest
from collections import deque
from functools import partial
from typing import List
import torch
import torch._dynamo
@@ -11,6 +14,8 @@ from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM80OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
@@ -18,8 +23,11 @@ from torch.testing._internal.common_utils import (
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
TEST_CUDA,
TEST_WITH_ROCM,
TestCase,
)
from torch.utils.hooks import RemovableHandle
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
@@ -217,6 +225,136 @@ def forward(self, arg0_1, arg1_1, arg2_1):
res.sum().backward()
@unittest.skipIf(IS_WINDOWS, "triton")
@unittest.skipIf(TEST_WITH_ROCM, "triton")
@unittest.skipIf(not SM80OrLater, "triton")
@unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
@unittest.skipIf(not TEST_CUDA, "triton")
@skipIfNoDynamoSupport
def test_register_effectful_custom_op(self):
with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch.library.define(
"mylib::record_scalar_tensor",
"(Tensor x, str prefix) -> ()",
lib=lib,
)
# Global dict storing the recorded tensors, keyed by prefix.
recorded_dict = {}
# PyTorch custom op implementation
@torch.library.impl(
"mylib::record_scalar_tensor",
"CompositeExplicitAutograd",
lib=lib,
)
def record_scalar_tensor(x, prefix):
recorded_dict[prefix] = x.clone()
return
# Meta function of the custom op
@torch.library.impl_abstract(
"mylib::record_scalar_tensor",
lib=lib,
)
def record_scalar_tensor_meta(x, prefix):
return
from torch._higher_order_ops.effects import (
_EffectType,
_register_effectful_op,
)
_register_effectful_op(
torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
)
my_config = {}
my_config["MockModule"] = "mean"
my_config["MockModule.linear"] = "mean"
my_config["MockModule.relu"] = "mean"
class MyLinear(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = torch.nn.Parameter(
torch.randn(out_features, in_features), requires_grad=True
)
self.bias = torch.nn.Parameter(
torch.randn(out_features), requires_grad=True
)
def forward(self, x):
return torch.nn.functional.linear(x, self.weight, self.bias)
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = MyLinear(10, 10)
self.register_buffer(
"buf0", torch.randn(10, 10, requires_grad=True)
)
def forward(self, x):
return torch.nn.functional.relu(self.linear(x) + self.buf0)
def forward_hook(
module: torch.nn.Module,
inputs: torch.Tensor,
output: torch.Tensor,
prefix: str,
aggregate_method: str,
) -> torch.Tensor:
if aggregate_method == "mean":
torch.ops.mylib.record_scalar_tensor(output.mean(), prefix)
elif aggregate_method == "max":
torch.ops.mylib.record_scalar_tensor(output.max(), prefix)
else:
# demo purpose: fall back to sum for any other aggregation method
torch.ops.mylib.record_scalar_tensor(output.sum(), prefix)
return output
def add_hooks(module, config):
handles: List[RemovableHandle] = []
q = deque([(module.__class__.__name__, module)])
while q:
name, m = q.pop()
children = [(name + "." + n, y) for (n, y) in m.named_children()]
q.extend(children)
aggregate_method = config.get(name, "mean")
prefix = name + ":" + aggregate_method
handle = m.register_forward_hook(
partial(
forward_hook,
prefix=prefix,
aggregate_method=aggregate_method,
)
)
if handle:
handles.append(handle)
return handles
x = torch.randn(10, 10, device="cuda")
mod = MockModule().to("cuda")
add_hooks(mod, my_config)
opt_mod = torch.compile(backend="inductor")(mod)
y = opt_mod(x)
self.assertTrue(torch.allclose(y, mod(x)))
# Ensure it works well with backward
y.sum().backward()
# Ensure the grad exists
self.assertTrue(isinstance(opt_mod.linear.weight.grad, torch.Tensor))
self.assertEqual(len(recorded_dict), 2)
self.assertTrue("MockModule.linear:mean" in recorded_dict)
self.assertTrue("MockModule:mean" in recorded_dict)
if __name__ == "__main__":
run_tests()

View File

@@ -22,6 +22,16 @@ SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = {
}
def _register_effectful_op(op: torch._ops.OpOverload, effect: _EffectType):
assert isinstance(op, torch._ops.OpOverload) and not has_aliasing(op)
if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
raise RuntimeError(
f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
f"trying to register a different effect type {effect}."
)
SIDE_EFFECTS[op] = effect
class WithEffects(HigherOrderOperator):
"""
with_effects(token, op, args, kwargs) -> (new_token, op_results)