[fx] Add hooks to intercept node replacements. (#117825)

Summary: Add an experimental API to the FX graph module that installs "hooks" which fire every time we change or replace a node in a graph, so that we can properly propagate the new name into the graph signature and potentially other places.

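For reference, a minimal sketch of how the new hooks are meant to be wired up (names like `run_my_pass` are hypothetical; the real wiring is in the diffs below):

    import copy

    # Keep the exported program's graph signature in sync while a pass
    # renames or replaces nodes in its graph module.
    gm = copy.deepcopy(exported_program.graph_module)
    sig = copy.deepcopy(exported_program.graph_signature)
    with gm._set_replace_hook(sig.get_replace_hook()):
        run_my_pass(gm)  # hypothetical pass that replaces/renames nodes
    # `sig` now reflects the post-pass node names.
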
Test Plan:
buck test mode/opt  -c fbcode.enable_gpu_sections=true caffe2/test/distributed/_tensor/experimental:tp_transform

buck test mode/opt caffe2/test:test_export -- -r test_replace_hook

Differential Revision: D52896531

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117825
Approved by: https://github.com/avikchaudhuri
Zhengxu Chen authored 2024-01-23 22:28:40 +00:00, committed by PyTorch MergeBot
parent b369888bec
commit abd759d50d
7 changed files with 120 additions and 8 deletions


@@ -570,6 +570,7 @@ API Reference
 .. autoclass:: ExportGraphSignature

     .. automethod:: replace_all_uses
+    .. automethod:: get_replace_hook

 .. autoclass:: torch.export.graph_signature.CustomObjArgument


@@ -1,11 +1,13 @@
 # Owner(s): ["oncall: export"]
+import copy
 import unittest

 import torch
 from functorch.experimental import control_flow
 from torch._dynamo.eval_frame import is_dynamo_supported
-from torch.export import export
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
+from torch.export import export
+from torch.fx.passes.infra.pass_base import PassResult
 from torch.testing._internal.common_utils import run_tests, TestCase
@@ -141,6 +143,45 @@ class TestPassInfra(TestCase):
         old_signature = ep_before.graph_signature
         self.assertNotEqual(new_signature.user_outputs, old_signature.user_outputs)

+    def test_replace_hook_basic(self) -> None:
+        class CustomModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
+                self.register_buffer("my_buffer1", torch.tensor(3.0))
+                self.register_buffer("my_buffer2", torch.tensor(4.0))
+
+            def forward(self, x1, x2):
+                # Use the parameter, buffers, and both inputs in the forward method
+                output = (
+                    x1 + self.my_parameter
+                ) * self.my_buffer1 + x2 * self.my_buffer2
+                return output
+
+        my_module = CustomModule()
+        inputs = (torch.tensor(6.0), torch.tensor(7.0))
+        ep_before = export(my_module, inputs)
+
+        def replace_pass(gm):
+            for node in gm.graph.nodes:
+                if node.op == "call_function":
+                    node.name = node.name + "_modified"
+            gm.recompile()
+            return PassResult(gm, True)
+
+        gm = copy.deepcopy(ep_before.graph_module)
+        sig = copy.deepcopy(ep_before.graph_signature)
+
+        with gm._set_replace_hook(sig.get_replace_hook()):
+            replace_pass(gm)
+
+        for node_name in sig.user_outputs:
+            self.assertTrue("_modified" in node_name)
+
+        old_signature = ep_before.graph_signature
+        self.assertNotEqual(sig.user_outputs, old_signature.user_outputs)
+
 if __name__ == '__main__':
     run_tests()

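The test exercises the hook contract directly: the hook receives the old node, the NAME of the new node (a string, not a Node), and the user node that consumes the old node. A hypothetical standalone hook with that shape:

    import torch.fx

    # Keep the parameter names old/new/user: call sites pass them as keywords.
    def my_replace_hook(old: torch.fx.Node, new: str, user: torch.fx.Node) -> None:
        # Called once per (old, user) pair whenever `old` is replaced or renamed.
        print(f"{old.name} -> {new} (observed at user {user.name})")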

@@ -47,17 +47,24 @@ def tensor_parallel_transformation(
     .. warning::
         This API is experimental and subject to change.
     """
-    return exported_program._transform_do_not_use(
-        TensorParallelTransformPass(
+    # TODO Migrate this to plain function call.
+    gm = exported_program.graph_module
+    sig = copy.deepcopy(exported_program.graph_signature)
+    state_dict = copy.copy(exported_program.state_dict)
+    with gm._set_replace_hook(sig.get_replace_hook()):
+        res = TensorParallelTransformPass(
             rank,
             world_size,
             device_type,
-            exported_program.state_dict,
+            state_dict,
             exported_program.graph_signature,
             parallel_strategies,
-        )
-    )
+        )(gm)
+    assert res is not None
+    gm = res.graph_module
+
+    return exported_program._update(gm, sig, state_dict)


 class TensorParallelTransformPass(PassBase):

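The hunk above swaps the deprecated `_transform_do_not_use` path for an explicit copy → hook → pass → `_update` sequence. Condensed, the pattern looks like this (`MyPass` stands in for any PassBase subclass):

    gm = exported_program.graph_module
    sig = copy.deepcopy(exported_program.graph_signature)
    state_dict = copy.copy(exported_program.state_dict)
    with gm._set_replace_hook(sig.get_replace_hook()):
        res = MyPass()(gm)  # the hook keeps `sig` consistent with any renames
    assert res is not None
    return exported_program._update(res.graph_module, sig, state_dict)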

@@ -579,6 +579,22 @@ class ExportedProgram:
     def _validate(self):
         self.verifier().check(self)

+    # TODO(zhxchen17) Formalize this.
+    def _update(
+        self, graph_module, graph_signature, state_dict=None
+    ) -> "ExportedProgram":
+        return ExportedProgram(
+            root=graph_module,
+            graph=graph_module.graph,
+            graph_signature=graph_signature,
+            state_dict=state_dict or self.state_dict,
+            range_constraints=copy.deepcopy(self.range_constraints),
+            module_call_graph=copy.deepcopy(self._module_call_graph),
+            example_inputs=self.example_inputs,
+            verifier=self.verifier,
+            tensor_constants=self.tensor_constants,
+        )
+

 def _get_updated_range_constraints(
     gm: torch.fx.GraphModule,

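`_update` thus swaps in a new graph module and signature while carrying the remaining metadata (range constraints, module call graph, example inputs, verifier) over from `self`. A usage sketch (`new_gm`, `new_sig` hypothetical):

    new_ep = ep._update(new_gm, new_sig)                  # reuses ep.state_dict
    new_ep = ep._update(new_gm, new_sig, new_state_dict)  # or supply a new one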

@@ -424,7 +424,19 @@ class ExportGraphSignature:
         """
         assert isinstance(old, str)
         assert isinstance(new, str)
+        arg_types = (TensorArgument, SymIntArgument, CustomObjArgument)
         for o in self.output_specs:
-            if isinstance(o.arg, TensorArgument):
+            if isinstance(o.arg, arg_types):
                 if o.arg.name == old:
                     o.arg.name = new
+        for i in self.input_specs:
+            if isinstance(i.arg, arg_types):
+                if i.arg.name == old:
+                    i.arg.name = new
+
+    def get_replace_hook(self):
+        def _(old, new, user):
+            if user.op in ("output", "input"):
+                self.replace_all_uses(old.name, new)
+
+        return _

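The returned hook only propagates a rename into the signature when the consuming node sits at the signature boundary (an "output"/"input" op); replacements among purely internal users are ignored. A sketch with hypothetical nodes:

    hook = sig.get_replace_hook()
    hook(old=old_node, new="add_modified", user=output_node)  # updates sig specs
    hook(old=old_node, new="add_modified", user=mul_node)     # no-op: internal user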

@ -1,3 +1,4 @@
import contextlib
import copy import copy
import itertools import itertools
import linecache import linecache
@ -445,6 +446,7 @@ class GraphModule(torch.nn.Module):
# Dictionary to store metadata # Dictionary to store metadata
self.meta: Dict[str, Any] = {} self.meta: Dict[str, Any] = {}
self._replace_hook = None
# TorchScript breaks trying to compile the graph setter because of the # TorchScript breaks trying to compile the graph setter because of the
# continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
@ -799,6 +801,7 @@ class {module_name}(torch.nn.Module):
"_state_dict_hooks", "_state_dict_hooks",
"_load_state_dict_pre_hooks", "_load_state_dict_pre_hooks",
"_load_state_dict_post_hooks", "_load_state_dict_post_hooks",
"_replace_hook",
] ]
for attr in extra_preserved_attrs: for attr in extra_preserved_attrs:
if attr in self.__dict__: if attr in self.__dict__:
@ -849,6 +852,21 @@ class {module_name}(torch.nn.Module):
new_gm._is_replica = True new_gm._is_replica = True
return new_gm return new_gm
@contextlib.contextmanager
def _set_replace_hook(self, f):
"""
Takes a callable which will be called everytime when we replace a node
to a new node, or change the node's name. Callable takes three arguments:
the old node we're changing, and NAME of the new node, followed by the
user node which consumes the old node to be replaced.
"""
assert callable(f), "Replace hook must be a callable."
prev, self._replace_hook = self._replace_hook, f
try:
yield
finally:
self._replace_hook = prev
# workarounds for issues in __torch_function__ # workarounds for issues in __torch_function__

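Because `_set_replace_hook` saves the previous hook and restores it in `finally`, installations nest cleanly. A sketch (`hook_a`, `hook_b` hypothetical callables):

    with gm._set_replace_hook(hook_a):
        with gm._set_replace_hook(hook_b):
            ...  # hook_b observes replacements here
        ...      # hook_a is active again
    # afterwards, whatever hook (if any) was installed before is back in place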

@@ -562,6 +562,7 @@ class Node:
                 replace_with.meta[k] = v
         to_process = list(self.users)
         skipped = []
+        m = self.graph.owning_module
         for use_node in to_process:
             if not delete_user_cb(use_node):
                 skipped.append(use_node)
@@ -573,6 +574,9 @@ class Node:
                 else:
                     return n

+            if getattr(m, "_replace_hook", None):
+                m._replace_hook(old=self, new=replace_with.name, user=use_node)
+
             new_args = map_arg(use_node.args, maybe_replace_node)
             new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
             assert isinstance(new_args, tuple)
@@ -662,6 +666,10 @@ class Node:
         def maybe_replace_node(n : Node) -> Node:
             return new_input if n == old_input else n

+        m = self.graph.owning_module
+        if getattr(m, "_replace_hook", None):
+            m._replace_hook(old=old_input, new=new_input.name, user=self)
+
         new_args = map_arg(self.args, maybe_replace_node)
         new_kwargs = map_arg(self.kwargs, maybe_replace_node)
         assert isinstance(new_args, tuple)
@@ -675,6 +683,15 @@ class Node:
         self.name = name
         self.graph._graph_namespace._rename_object(self, name)

+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == 'name' and hasattr(self, "name"):
+            m = self.graph.owning_module
+            if getattr(m, "_replace_hook", None):
+                assert isinstance(value, str)
+                for user in self.users:
+                    m._replace_hook(old=self, new=value, user=user)
+        object.__setattr__(self, name, value)
+

 @compatibility(is_backward_compatible=True)
 def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
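
Taken together, three mutation paths on `Node` now notify an installed hook; sketched below with hypothetical nodes `a`, `b`, `n` owned by `gm` (note the hook always receives the new NAME, never the new node):

    with gm._set_replace_hook(my_replace_hook):
        a.replace_all_uses_with(b)  # fires once per user of `a`, new=b.name
        n.replace_input_with(a, b)  # fires once, user=n
        a.name = "a_renamed"        # fires once per user of `a`, new="a_renamed"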