[effects] Add way to register effectful op (#122348)
This adds a way to register an operator as being effectful. I also added a test case which mimics our solution for intermediate logging ([doc](https://docs.google.com/document/d/1eLyGDVe4iplVFiO0I021uLgA4Y6HxK9eqn55e9KzQkc/edit#heading=h.uwec2ukkwhea)), which works by creating a custom op and registering it as effectful.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122348
Approved by: https://github.com/zou3519
ghstack dependencies: #122347
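For orientation, here is a condensed sketch of the pattern the new test exercises: define a custom op, give it a meta kernel, and mark it as effectful so compilation preserves and orders its calls. It is distilled from the test diff below; `mylib::record_scalar_tensor` and the `recorded` dict are the test's own illustrative names, and `_EffectType` / `_register_effectful_op` are the private helpers this PR adds, not a public API.

```python
import torch
from torch._higher_order_ops.effects import _EffectType, _register_effectful_op

recorded = {}  # sink for logged values, mirroring the test below

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    # A side-effectful custom op: records a scalar tensor under a prefix, returns nothing.
    torch.library.define(
        "mylib::record_scalar_tensor", "(Tensor x, str prefix) -> ()", lib=lib
    )

    @torch.library.impl("mylib::record_scalar_tensor", "CompositeExplicitAutograd", lib=lib)
    def record_scalar_tensor(x, prefix):
        recorded[prefix] = x.clone()

    # Meta kernel so the op can be traced without running the real side effect.
    @torch.library.impl_abstract("mylib::record_scalar_tensor", lib=lib)
    def record_scalar_tensor_meta(x, prefix):
        return

    # Register the op as having an ORDERED effect so tracing keeps the calls
    # (and keeps them in program order) instead of dead-code-eliminating them.
    _register_effectful_op(torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED)
```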
parent 493478db4a
commit 7420b8c5be
@@ -1,5 +1,8 @@
# Owner(s): ["module: functorch"]
import unittest
from collections import deque
from functools import partial
from typing import List

import torch
import torch._dynamo
@@ -11,6 +14,8 @@ from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import _get_torch_cuda_version, SM80OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
@@ -18,8 +23,11 @@ from torch.testing._internal.common_utils import (
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    TEST_CUDA,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.utils.hooks import RemovableHandle


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
@@ -217,6 +225,136 @@ def forward(self, arg0_1, arg1_1, arg2_1):

        res.sum().backward()

    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(TEST_WITH_ROCM, "triton")
    @unittest.skipIf(not SM80OrLater, "triton")
    @unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
    @unittest.skipIf(not TEST_CUDA, "triton")
    @skipIfNoDynamoSupport
    def test_register_effectful_custom_op(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.capture_dynamic_output_shape_ops = True

            torch.library.define(
                "mylib::record_scalar_tensor",
                "(Tensor x, str prefix) -> ()",
                lib=lib,
            )

            # Global dict to store the recorded tensors keyed by prefix.
            recorded_dict = {}

            # PyTorch custom op implementation
            @torch.library.impl(
                "mylib::record_scalar_tensor",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def record_scalar_tensor(x, prefix):
                recorded_dict[prefix] = x.clone()
                return

            # Meta function of the custom op
            @torch.library.impl_abstract(
                "mylib::record_scalar_tensor",
                lib=lib,
            )
            def record_scalar_tensor_meta(x, prefix):
                return

            from torch._higher_order_ops.effects import (
                _EffectType,
                _register_effectful_op,
            )

            _register_effectful_op(
                torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
            )

            my_config = {}
            my_config["MockModule"] = "mean"
            my_config["MockModule.linear"] = "mean"
            my_config["MockModule.relu"] = "mean"

            class MyLinear(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.weight = torch.nn.Parameter(
                        torch.randn(out_features, in_features), requires_grad=True
                    )
                    self.bias = torch.nn.Parameter(
                        torch.randn(out_features), requires_grad=True
                    )

                def forward(self, x):
                    return torch.nn.functional.linear(x, self.weight, self.bias)

            class MockModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.linear = MyLinear(10, 10)
                    self.register_buffer(
                        "buf0", torch.randn(10, 10, requires_grad=True)
                    )

                def forward(self, x):
                    return torch.nn.functional.relu(self.linear(x) + self.buf0)

            def forward_hook(
                module: torch.nn.Module,
                inputs: torch.Tensor,
                output: torch.Tensor,
                prefix: str,
                aggregate_method: str,
            ) -> torch.Tensor:
                if aggregate_method == "mean":
                    torch.ops.mylib.record_scalar_tensor(output.mean(), prefix)
                elif aggregate_method == "max":
                    torch.ops.mylib.record_scalar_tensor(output.max(), prefix)
                else:
                    # demo purpose: fall back to sum for any other method
                    torch.ops.mylib.record_scalar_tensor(output.sum(), prefix)
                return output

            def add_hooks(module, config):
                handles: List[RemovableHandle] = []
                q = deque([(module.__class__.__name__, module)])
                while q:
                    name, m = q.pop()
                    children = [(name + "." + n, y) for (n, y) in m.named_children()]
                    q.extend(children)
                    aggregate_method = config.get(name, "mean")
                    prefix = name + ":" + aggregate_method
                    handle = m.register_forward_hook(
                        partial(
                            forward_hook,
                            prefix=prefix,
                            aggregate_method=aggregate_method,
                        )
                    )
                    if handle:
                        handles.append(handle)
                return handles

            x = torch.randn(10, 10, device="cuda")
            mod = MockModule().to("cuda")

            add_hooks(mod, my_config)

            opt_mod = torch.compile(backend="inductor")(mod)
            y = opt_mod(x)

            self.assertTrue(torch.allclose(y, mod(x)))
            # Ensure it works well with backward
            y.sum().backward()
            # Ensure the grad exists
            self.assertTrue(isinstance(opt_mod.linear.weight.grad, torch.Tensor))

            self.assertEqual(len(recorded_dict), 2)
            self.assertTrue("MockModule.linear:mean" in recorded_dict)
            self.assertTrue("MockModule:mean" in recorded_dict)


if __name__ == "__main__":
    run_tests()

@@ -22,6 +22,16 @@ SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = {
}


def _register_effectful_op(op: torch._ops.OpOverload, effect: _EffectType):
    assert isinstance(op, torch._ops.OpOverload) and not has_aliasing(op)
    if op in SIDE_EFFECTS and SIDE_EFFECTS[op] != effect:
        raise RuntimeError(
            f"Already registered effect type {SIDE_EFFECTS[op]} to op {op}, "
            f"trying to register a different effect type {effect}."
        )
    SIDE_EFFECTS[op] = effect


class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, args, kwargs) -> (new_token, op_results)
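As a side note on the calling convention shown in the docstring above: threading a token into and out of each effectful call gives the tracer a data dependency it must preserve, which is what enforces ordering. A minimal, library-free sketch of that idea (purely illustrative, not PyTorch's implementation; `toy_with_effects` is a hypothetical name):

```python
from typing import Any, Callable, Tuple

def toy_with_effects(token: int, op: Callable[..., Any], *args: Any) -> Tuple[int, Any]:
    # Consume the previous token and return a fresh one alongside the op's result.
    # Because call N+1 needs the token produced by call N, a tracer that only
    # tracks data dependencies cannot reorder or drop the effectful calls.
    result = op(*args)
    return token + 1, result

log = []
token = 0
token, _ = toy_with_effects(token, log.append, "first")
token, _ = toy_with_effects(token, log.append, "second")
assert log == ["first", "second"]
```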