[fx] Add hooks to intercept node replacements. (#117825)
Summary: Adds an experimental API to FX GraphModule that installs "hooks" which fire every time a node in a graph is changed or replaced, so that the new name can be propagated to the graph signature and potentially other places.

Test Plan:
buck test mode/opt -c fbcode.enable_gpu_sections=true caffe2/test/distributed/_tensor/experimental:tp_transform
buck test mode/opt caffe2/test:test_export -- -r test_replace_hook

Differential Revision: D52896531
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117825
Approved by: https://github.com/avikchaudhuri
This commit is contained in: parent b369888bec, commit abd759d50d
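In outline, the workflow these hooks enable looks like the following (a minimal sketch distilled from the test and the tp_transform change below; my_module, example_inputs, and run_my_pass are hypothetical placeholders):

import copy

from torch.export import export

ep = export(my_module, example_inputs)  # my_module / example_inputs: assumed defined
gm = copy.deepcopy(ep.graph_module)
sig = copy.deepcopy(ep.graph_signature)

# While the hook is installed, every node replacement or rename in `gm`
# is mirrored into the copied signature's input/output specs.
with gm._set_replace_hook(sig.get_replace_hook()):
    run_my_pass(gm)  # run_my_pass: any FX pass that mutates gm in place

ep = ep._update(gm, sig)  # rebuild the ExportedProgram with the synced signature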
docs/source/export.rst
@@ -570,6 +570,7 @@ API Reference
 .. autoclass:: ExportGraphSignature

     .. automethod:: replace_all_uses

+    .. automethod:: get_replace_hook

 .. autoclass:: torch.export.graph_signature.CustomObjArgument
test/export/test_pass_infra.py
@@ -1,11 +1,13 @@
 # Owner(s): ["oncall: export"]
+import copy
 import unittest

 import torch
 from functorch.experimental import control_flow
 from torch._dynamo.eval_frame import is_dynamo_supported
-from torch.export import export
 from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
+from torch.export import export
+from torch.fx.passes.infra.pass_base import PassResult
 from torch.testing._internal.common_utils import run_tests, TestCase
@@ -141,6 +143,45 @@ class TestPassInfra(TestCase):
         old_signature = ep_before.graph_signature
         self.assertNotEqual(new_signature.user_outputs, old_signature.user_outputs)

+    def test_replace_hook_basic(self) -> None:
+        class CustomModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
+
+                self.register_buffer("my_buffer1", torch.tensor(3.0))
+                self.register_buffer("my_buffer2", torch.tensor(4.0))
+
+            def forward(self, x1, x2):
+                # Use the parameter, buffers, and both inputs in the forward method
+                output = (
+                    x1 + self.my_parameter
+                ) * self.my_buffer1 + x2 * self.my_buffer2
+                return output
+
+        my_module = CustomModule()
+        inputs = (torch.tensor(6.0), torch.tensor(7.0))
+        ep_before = export(my_module, inputs)
+
+        def replace_pass(gm):
+            for node in gm.graph.nodes:
+                if node.op == "call_function":
+                    node.name = node.name + "_modified"
+            gm.recompile()
+            return PassResult(gm, True)
+
+        gm = copy.deepcopy(ep_before.graph_module)
+        sig = copy.deepcopy(ep_before.graph_signature)
+
+        with gm._set_replace_hook(sig.get_replace_hook()):
+            replace_pass(gm)
+
+        for node_name in sig.user_outputs:
+            self.assertTrue("_modified" in node_name)
+
+        old_signature = ep_before.graph_signature
+        self.assertNotEqual(sig.user_outputs, old_signature.user_outputs)
+

 if __name__ == '__main__':
     run_tests()
torch/distributed/_tensor/experimental/tp_transform.py
@@ -47,17 +47,24 @@ def tensor_parallel_transformation(
     .. warning::
        This API is experimental and subject to change.
     """
-    # TODO Migrate this to plain function call.
-    return exported_program._transform_do_not_use(
-        TensorParallelTransformPass(
+    gm = exported_program.graph_module
+    sig = copy.deepcopy(exported_program.graph_signature)
+    state_dict = copy.copy(exported_program.state_dict)
+
+    with gm._set_replace_hook(sig.get_replace_hook()):
+        res = TensorParallelTransformPass(
             rank,
             world_size,
             device_type,
-            exported_program.state_dict,
+            state_dict,
             exported_program.graph_signature,
             parallel_strategies,
-        )
-    )
+        )(gm)
+        assert res is not None
+        gm = res.graph_module
+
+    return exported_program._update(gm, sig, state_dict)


 class TensorParallelTransformPass(PassBase):
torch/export/exported_program.py
@@ -579,6 +579,22 @@ class ExportedProgram:
     def _validate(self):
        self.verifier().check(self)

+    # TODO(zhxchen17) Formalize this.
+    def _update(
+        self, graph_module, graph_signature, state_dict=None
+    ) -> "ExportedProgram":
+        return ExportedProgram(
+            root=graph_module,
+            graph=graph_module.graph,
+            graph_signature=graph_signature,
+            state_dict=state_dict or self.state_dict,
+            range_constraints=copy.deepcopy(self.range_constraints),
+            module_call_graph=copy.deepcopy(self._module_call_graph),
+            example_inputs=self.example_inputs,
+            verifier=self.verifier,
+            tensor_constants=self.tensor_constants,
+        )
+

 def _get_updated_range_constraints(
     gm: torch.fx.GraphModule,
torch/export/graph_signature.py
@@ -424,7 +424,19 @@ class ExportGraphSignature:
         """
         assert isinstance(old, str)
         assert isinstance(new, str)
+        arg_types = (TensorArgument, SymIntArgument, CustomObjArgument)
         for o in self.output_specs:
-            if isinstance(o.arg, TensorArgument):
+            if isinstance(o.arg, arg_types):
                 if o.arg.name == old:
                     o.arg.name = new
+        for i in self.input_specs:
+            if isinstance(i.arg, arg_types):
+                if i.arg.name == old:
+                    i.arg.name = new
+
+    def get_replace_hook(self):
+        def _(old, new, user):
+            if user.op in ("output", "input"):
+                self.replace_all_uses(old.name, new)
+
+        return _
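To illustrate the widened replace_all_uses semantics, a small sketch (assuming an ExportedProgram ep whose signature has a user output named "add"; the name is hypothetical):

import copy

sig = copy.deepcopy(ep.graph_signature)
sig.replace_all_uses("add", "add_modified")
# Input specs and SymInt/CustomObj arguments are now scanned too, not just
# tensor output specs, so the rename lands wherever the old name appears.
assert "add_modified" in sig.user_outputs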
torch/fx/graph_module.py
@@ -1,3 +1,4 @@
+import contextlib
 import copy
 import itertools
 import linecache
@@ -445,6 +446,7 @@ class GraphModule(torch.nn.Module):
         # Dictionary to store metadata
         self.meta: Dict[str, Any] = {}
+        self._replace_hook = None

         # TorchScript breaks trying to compile the graph setter because of the
         # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
@@ -799,6 +801,7 @@ class {module_name}(torch.nn.Module):
             "_state_dict_hooks",
             "_load_state_dict_pre_hooks",
             "_load_state_dict_post_hooks",
+            "_replace_hook",
         ]
         for attr in extra_preserved_attrs:
             if attr in self.__dict__:
@@ -849,6 +852,21 @@ class {module_name}(torch.nn.Module):
         new_gm._is_replica = True
         return new_gm

+    @contextlib.contextmanager
+    def _set_replace_hook(self, f):
+        """
+        Takes a callable which will be called every time we replace a node
+        with a new node, or change a node's name. The callable takes three
+        arguments: the old node being changed, the NAME of the new node,
+        and the user node which consumes the old node being replaced.
+        """
+        assert callable(f), "Replace hook must be a callable."
+        prev, self._replace_hook = self._replace_hook, f
+        try:
+            yield
+        finally:
+            self._replace_hook = prev
+

 # workarounds for issues in __torch_function__
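The hook contract itself is easiest to see with a toy hook (hypothetical, for illustration only; gm is a GraphModule and a, b are nodes in its graph):

def log_replacements(old, new, user):
    # old:  the fx.Node being replaced or renamed
    # new:  the NAME (a str) of the replacement
    # user: the node that consumes `old`
    print(f"{user.name}: {old.name} -> {new}")

with gm._set_replace_hook(log_replacements):
    a.replace_all_uses_with(b)  # fires log_replacements once per user of `a`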
torch/fx/node.py
@@ -562,6 +562,7 @@ class Node:
             replace_with.meta[k] = v
         to_process = list(self.users)
         skipped = []
+        m = self.graph.owning_module
         for use_node in to_process:
             if not delete_user_cb(use_node):
                 skipped.append(use_node)
@@ -573,6 +574,9 @@ class Node:
                 else:
                     return n

+            if getattr(m, "_replace_hook", None):
+                m._replace_hook(old=self, new=replace_with.name, user=use_node)
+
             new_args = map_arg(use_node.args, maybe_replace_node)
             new_kwargs = map_arg(use_node.kwargs, maybe_replace_node)
             assert isinstance(new_args, tuple)
@@ -662,6 +666,10 @@ class Node:
         def maybe_replace_node(n : Node) -> Node:
             return new_input if n == old_input else n

+        m = self.graph.owning_module
+        if getattr(m, "_replace_hook", None):
+            m._replace_hook(old=old_input, new=new_input.name, user=self)
+
         new_args = map_arg(self.args, maybe_replace_node)
         new_kwargs = map_arg(self.kwargs, maybe_replace_node)
         assert isinstance(new_args, tuple)
@@ -675,6 +683,15 @@ class Node:
         self.name = name
         self.graph._graph_namespace._rename_object(self, name)

+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == 'name' and hasattr(self, "name"):
+            m = self.graph.owning_module
+            if getattr(m, "_replace_hook", None):
+                assert isinstance(value, str)
+                for user in self.users:
+                    m._replace_hook(old=self, new=value, user=user)
+        object.__setattr__(self, name, value)
+

 @compatibility(is_backward_compatible=True)
 def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
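Because __setattr__ is now intercepted, a plain rename also notifies the hook, which is what makes the rename-only pass in the test above work; a sketch reusing names from that test:

with gm._set_replace_hook(sig.get_replace_hook()):
    for node in gm.graph.nodes:
        if node.op == "call_function":
            # Each assignment fires the hook once per user of the node.
            node.name = node.name + "_modified"
gm.recompile()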