Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[fx] Add hooks to intercept node replacements. (#117825)
Summary: Add an experimental API to FX GraphModule that installs "hooks" which fire every time a node in a graph is changed or replaced, so that the new name can be propagated to the graph signature and, potentially, other places.

Test Plan:
buck test mode/opt -c fbcode.enable_gpu_sections=true caffe2/test/distributed/_tensor/experimental:tp_transform
buck test mode/opt caffe2/test:test_export -- -r test_replace_hook

Differential Revision: D52896531

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117825
Approved by: https://github.com/avikchaudhuri
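The two new pieces are designed to be used together: ExportGraphSignature.get_replace_hook produces a callable, and GraphModule._set_replace_hook installs it for the duration of a pass, so that node renames inside the graph are mirrored into the signature. A minimal sketch of the intended flow, condensed from the test added below (the module M is illustrative):

    import copy

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    ep = export(M(), (torch.randn(2),))
    gm = copy.deepcopy(ep.graph_module)
    sig = copy.deepcopy(ep.graph_signature)

    # While the hook is installed, renaming or replacing a node in gm is
    # propagated into sig via ExportGraphSignature.replace_all_uses.
    with gm._set_replace_hook(sig.get_replace_hook()):
        for node in gm.graph.nodes:
            if node.op == "call_function":
                node.name = node.name + "_modified"
        gm.recompile()

    assert all("_modified" in name for name in sig.user_outputs)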
This commit is contained in: parent b369888bec, commit abd759d50d
@@ -570,6 +570,7 @@ API Reference
 .. autoclass:: ExportGraphSignature

     .. automethod:: replace_all_uses

+    .. automethod:: get_replace_hook

 .. autoclass:: torch.export.graph_signature.CustomObjArgument
@@ -1,11 +1,13 @@
 # Owner(s): ["oncall: export"]
+import copy
 import unittest

 import torch
 from functorch.experimental import control_flow
 from torch._dynamo.eval_frame import is_dynamo_supported
-from torch.export import export
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
+from torch.export import export
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.testing._internal.common_utils import run_tests, TestCase

+
@@ -141,6 +143,45 @@ class TestPassInfra(TestCase):
         old_signature = ep_before.graph_signature
         self.assertNotEqual(new_signature.user_outputs, old_signature.user_outputs)

+    def test_replace_hook_basic(self) -> None:
+        class CustomModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
+
+                self.register_buffer("my_buffer1", torch.tensor(3.0))
+                self.register_buffer("my_buffer2", torch.tensor(4.0))
+
+            def forward(self, x1, x2):
+                # Use the parameter, buffers, and both inputs in the forward method
+                output = (
+                    x1 + self.my_parameter
+                ) * self.my_buffer1 + x2 * self.my_buffer2
+                return output
+
+        my_module = CustomModule()
+        inputs = (torch.tensor(6.0), torch.tensor(7.0))
+        ep_before = export(my_module, inputs)
+
+        def replace_pass(gm):
+            for node in gm.graph.nodes:
+                if node.op == "call_function":
+                    node.name = node.name + "_modified"
+            gm.recompile()
+            return PassResult(gm, True)
+
+        gm = copy.deepcopy(ep_before.graph_module)
+        sig = copy.deepcopy(ep_before.graph_signature)
+
+        with gm._set_replace_hook(sig.get_replace_hook()):
+            replace_pass(gm)
+
+        for node_name in sig.user_outputs:
+            self.assertTrue("_modified" in node_name)
+
+        old_signature = ep_before.graph_signature
+        self.assertNotEqual(sig.user_outputs, old_signature.user_outputs)

 if __name__ == '__main__':
     run_tests()
@@ -47,17 +47,24 @@ def tensor_parallel_transformation(
     .. warning::
         This API is experimental and subject to change.
     """
     # TODO Migrate this to plain function call.
-    return exported_program._transform_do_not_use(
-        TensorParallelTransformPass(
+    gm = exported_program.graph_module
+    sig = copy.deepcopy(exported_program.graph_signature)
+    state_dict = copy.copy(exported_program.state_dict)
+
+    with gm._set_replace_hook(sig.get_replace_hook()):
+        res = TensorParallelTransformPass(
             rank,
             world_size,
             device_type,
-            exported_program.state_dict,
+            state_dict,
             exported_program.graph_signature,
             parallel_strategies,
-        )
-    )
+        )(gm)
+    assert res is not None
+    gm = res.graph_module
+
+    return exported_program._update(gm, sig, state_dict)


 class TensorParallelTransformPass(PassBase):
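This hunk is the first in-tree adopter of the hooks: rather than routing through _transform_do_not_use, the pass is invoked directly on the graph module while the signature's replace hook is installed, and the pieces are reassembled with ExportedProgram._update. Any PassBase-style pass can follow the same shape; a hedged sketch of the generic pattern (run_pass_and_update and my_pass are illustrative names, not part of this change):

    import copy

    def run_pass_and_update(exported_program, my_pass):
        # my_pass: any callable pass that takes a GraphModule and returns a
        # PassResult, e.g. a torch.fx.passes.infra.pass_base.PassBase instance.
        gm = exported_program.graph_module
        sig = copy.deepcopy(exported_program.graph_signature)
        state_dict = copy.copy(exported_program.state_dict)

        # Node renames performed by the pass are mirrored into sig by the hook.
        with gm._set_replace_hook(sig.get_replace_hook()):
            res = my_pass(gm)
        assert res is not None

        # Re-wrap the mutated graph module in a fresh ExportedProgram.
        return exported_program._update(res.graph_module, sig, state_dict)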
@@ -579,6 +579,22 @@ class ExportedProgram:
     def _validate(self):
         self.verifier().check(self)

+    # TODO(zhxchen17) Formalize this.
+    def _update(
+        self, graph_module, graph_signature, state_dict=None
+    ) -> "ExportedProgram":
+        return ExportedProgram(
+            root=graph_module,
+            graph=graph_module.graph,
+            graph_signature=graph_signature,
+            state_dict=state_dict or self.state_dict,
+            range_constraints=copy.deepcopy(self.range_constraints),
+            module_call_graph=copy.deepcopy(self._module_call_graph),
+            example_inputs=self.example_inputs,
+            verifier=self.verifier,
+            tensor_constants=self.tensor_constants,
+        )
+

 def _get_updated_range_constraints(
     gm: torch.fx.GraphModule,
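_update re-wraps an already-mutated graph module and signature in a fresh ExportedProgram, deep-copying the range constraints and module call graph and reusing everything else from self. One subtlety: with the `state_dict or self.state_dict` fallback, an empty dict behaves the same as passing None. A hedged sketch of a call site (ep, new_gm, and new_sig are illustrative names for objects produced as in the tp_transform hunk above):

    # ep: an existing ExportedProgram; new_gm and new_sig were produced by a
    # pass run under the replace hook, as in tensor_parallel_transformation.
    new_ep = ep._update(new_gm, new_sig)  # state_dict falls back to ep.state_dict
    new_ep._validate()  # the rebuilt program can be re-checked by its verifier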
@@ -424,7 +424,19 @@ class ExportGraphSignature:
         """
         assert isinstance(old, str)
         assert isinstance(new, str)
+        arg_types = (TensorArgument, SymIntArgument, CustomObjArgument)
         for o in self.output_specs:
-            if isinstance(o.arg, TensorArgument):
+            if isinstance(o.arg, arg_types):
                 if o.arg.name == old:
                     o.arg.name = new
+        for i in self.input_specs:
+            if isinstance(i.arg, arg_types):
+                if i.arg.name == old:
+                    i.arg.name = new
+
+    def get_replace_hook(self):
+        def _(old, new, user):
+            if user.op in ("output", "input"):
+                self.replace_all_uses(old.name, new)
+
+        return _
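The signature-side effect is easiest to see in isolation: replace_all_uses rewrites the matching argument name in both input_specs and output_specs, which is what keeps derived properties like user_outputs consistent after a rename. A small sketch (the module is illustrative; the traced node name depends on tracing):

    import copy

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    sig = copy.deepcopy(export(M(), (torch.randn(2),)).graph_signature)
    old_name = sig.user_outputs[0]  # e.g. "add"; exact name depends on tracing
    sig.replace_all_uses(old_name, old_name + "_modified")
    assert sig.user_outputs[0] == old_name + "_modified"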
@@ -1,3 +1,4 @@
+import contextlib
 import copy
 import itertools
 import linecache
@@ -445,6 +446,7 @@ class GraphModule(torch.nn.Module):

         # Dictionary to store metadata
         self.meta: Dict[str, Any] = {}
+        self._replace_hook = None

         # TorchScript breaks trying to compile the graph setter because of the
         # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
@@ -799,6 +801,7 @@ class {module_name}(torch.nn.Module):
             "_state_dict_hooks",
             "_load_state_dict_pre_hooks",
             "_load_state_dict_post_hooks",
+            "_replace_hook",
         ]
         for attr in extra_preserved_attrs:
             if attr in self.__dict__:
@@ -849,6 +852,21 @@ class {module_name}(torch.nn.Module):
         new_gm._is_replica = True
         return new_gm

+    @contextlib.contextmanager
+    def _set_replace_hook(self, f):
+        """
+        Takes a callable which will be called every time we replace a node
+        with a new node, or change the node's name. The callable takes three
+        arguments: the old node being changed, the NAME of the new node, and
+        the user node which consumes the old node being replaced.
+        """
+        assert callable(f), "Replace hook must be a callable."
+        prev, self._replace_hook = self._replace_hook, f
+        try:
+            yield
+        finally:
+            self._replace_hook = prev
+

 # workarounds for issues in __torch_function__
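_set_replace_hook accepts any callable with the (old, new, user) shape described in the docstring, not just the one produced by get_replace_hook. A minimal sketch of a custom hook for debugging a pass (the hook name and print format are illustrative):

    def logging_replace_hook(old, new, user):
        # old: the torch.fx.Node being replaced or renamed
        # new: the NAME (a str) of the replacement node
        # user: a node that consumes `old`
        print(f"{user.name}: use of {old.name!r} becomes {new!r}")

    # gm is assumed to be an existing torch.fx.GraphModule.
    with gm._set_replace_hook(logging_replace_hook):
        ...  # run a pass that replaces or renames nodes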
@@ -562,6 +562,7 @@ class Node:
                 replace_with.meta[k] = v
         to_process = list(self.users)
         skipped = []
+        m = self.graph.owning_module
         for use_node in to_process:
             if not delete_user_cb(use_node):
                 skipped.append(use_node)
@@ -573,6 +574,9 @@ class Node:
                 else:
                     return n

+            if getattr(m, "_replace_hook", None):
+                m._replace_hook(old=self, new=replace_with.name, user=use_node)
+
             new_args = map_arg(use_node.args, maybe_replace_node)
             new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
             assert isinstance(new_args, tuple)
@@ -662,6 +666,10 @@ class Node:
         def maybe_replace_node(n : Node) -> Node:
             return new_input if n == old_input else n

+        m = self.graph.owning_module
+        if getattr(m, "_replace_hook", None):
+            m._replace_hook(old=old_input, new=new_input.name, user=self)
+
         new_args = map_arg(self.args, maybe_replace_node)
         new_kwargs = map_arg(self.kwargs, maybe_replace_node)
         assert isinstance(new_args, tuple)
@@ -675,6 +683,15 @@ class Node:
         self.name = name
         self.graph._graph_namespace._rename_object(self, name)

+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == 'name' and hasattr(self, "name"):
+            m = self.graph.owning_module
+            if getattr(m, "_replace_hook", None):
+                assert isinstance(value, str)
+                for user in self.users:
+                    m._replace_hook(old=self, new=value, user=user)
+        object.__setattr__(self, name, value)
+

 @compatibility(is_backward_compatible=True)
 def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
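Taken together, the three call sites above mean the hook fires whether a pass goes through Node.replace_all_uses_with, Node.replace_input_with, or a bare assignment to node.name (the __setattr__ interception). A quick sketch of the last, least obvious path, on a toy traced function:

    import torch
    from torch.fx import symbolic_trace

    def f(x):
        return torch.relu(x) + 1

    gm = symbolic_trace(f)
    events = []

    with gm._set_replace_hook(lambda old, new, user: events.append((old.name, new, user.name))):
        relu = next(n for n in gm.graph.nodes if n.target is torch.relu)
        relu.name = relu.name + "_modified"  # a plain rename still fires the hook

    print(events)  # one (old_name, new_name, user_name) entry per user of relu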