[export] Fix handling output in remove_effect_tokens_pass (#122357)

Added handling for updating the output_spec in the graph signature when the result of a with_effects call is a graph output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122357
Approved by: https://github.com/zhxchen17
angelayi 2024-03-22 03:35:59 +00:00 committed by PyTorch MergeBot
parent 09eb07bee8
commit fb57d1699b
3 changed files with 154 additions and 99 deletions
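
For context, the bug is easiest to see on a toy torch.fx graph: when a pass replaces a tuple-returning call and rewires its getitem users, any user that fed the graph output changes name, and an ExportedProgram's output_specs must be renamed to match — that renaming for with_effects results is what this commit adds. Below is a minimal sketch of the rewiring mechanics only; torch.split/torch.narrow are stand-ins for illustration, not the ops this pass handles:

import operator

import torch
from torch import fx


class M(torch.nn.Module):
    def forward(self, x):
        chunks = torch.split(x, 1)  # traced as a call_function with getitem users
        return chunks[0]


gm = fx.symbolic_trace(M())
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.split:
        with gm.graph.inserting_before(node):
            # Stand-in for "call the op directly, without the effect token"
            new_node = gm.graph.call_function(torch.narrow, (node.args[0], 0, 0, 1))
        for user in list(node.users):
            # Rewire getitem(split, 0) to the replacement node. In an exported
            # program, if this getitem was an output, the signature's output
            # spec must start naming new_node -- the case fixed by this commit.
            if user.target is operator.getitem and user.args[1] == 0:
                user.replace_all_uses_with(new_node)
gm.graph.eliminate_dead_code()
gm.recompile()

x = torch.randn(4, 3)
torch.testing.assert_close(gm(x), x[0:1])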

test/export/test_passes.py

@@ -34,19 +34,25 @@ from torch._export.utils import (
     sequential_split,
 )
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
+from torch._higher_order_ops.torchbind import enable_torchbind_tracing
 from torch.export import export
 from torch.export._remove_auto_functionalized_pass import (
     unsafe_remove_auto_functionalized_pass,
 )
+from torch.export._remove_effect_tokens_pass import _remove_effect_tokens
 from torch.fx.passes.infra.partitioner import Partition
 from torch.fx.passes.operator_support import OperatorSupport
 from torch.library import impl, _scoped_library
 from torch.testing import FileCheck
 from torch.testing._internal.common_utils import (
-    IS_WINDOWS,
     find_library_location,
     run_tests,
-    skipIfTorchDynamo,
     TestCase,
+    skipIfTorchDynamo,
+    IS_FBCODE,
+    IS_MACOS,
+    IS_SANDCASTLE,
+    IS_WINDOWS,
 )
 from torch.utils import _pytree as pytree
@@ -194,6 +200,18 @@ class TestPasses(TestCase):
         self.SEQUENTIAL_SPLIT_INLINE_TESTS = _sequential_split_inline_tests()
         self.SET_GRAD_ENABLED_TESTS = _set_grad_enabled_tests()
 
+        if IS_SANDCASTLE or IS_FBCODE:
+            torch.ops.load_library(
+                "//caffe2/test/cpp/jit:test_custom_class_registrations"
+            )
+        elif IS_MACOS:
+            raise unittest.SkipTest("non-portable load_library call used in test")
+        else:
+            lib_file_path = find_library_location('libtorchbind_test.so')
+            if IS_WINDOWS:
+                lib_file_path = find_library_location('torchbind_test.dll')
+            torch.ops.load_library(str(lib_file_path))
+
     def tearDown(self):
         self.SEQUENTIAL_SPLIT_INLINE_TESTS.clear()
         self.SET_GRAD_ENABLED_TESTS.clear()
@@ -354,6 +372,34 @@ class TestPasses(TestCase):
             if torch.Tag.core in op_overload.tags and is_view_op(op_overload._schema):
                 self.assertIsNotNone(get_view_copy_of_view_op(op_overload._schema))
 
+    def test_custom_obj_tuple_out(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
+
+            def forward(self, x):
+                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
+                y = a[0] + a[1]
+                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
+                return b
+
+        m = MyModule()
+        inputs = (torch.ones(2, 3),)
+        with enable_torchbind_tracing():
+            ep = torch.export.export(m, inputs, strict=False)
+
+        inp = torch.randn(2, 3)
+        orig_res = m(inp)
+        ep_res = ep.module()(inp)
+
+        without_token_ep = _remove_effect_tokens(ep)
+        without_token_ep.verifier().check(without_token_ep)
+        without_token_res = without_token_ep.module()(inp)
+
+        self.assertTrue(torch.allclose(orig_res, ep_res))
+        self.assertTrue(torch.allclose(orig_res, without_token_res))
+
     def test_runtime_assert_inline_constraints_for_item(self) -> None:
         class M(torch.nn.Module):
             def __init__(self):

torch/export/_remove_auto_functionalized_pass.py

@@ -15,6 +15,57 @@ from torch._higher_order_ops.auto_functionalize import (
 from torch.export import ExportedProgram
 
 
+def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes):
+    # Update every use of the HOP
+    for node in reversed(auto_functionalize_nodes):
+        func = node.args[0]
+        original_kwargs = node.kwargs
+        assert isinstance(func, torch._ops.OpOverload)
+
+        with ep.graph.inserting_before(node):
+            # This makes the call_function refer to every arg as a kwarg;
+            # this is weird but probably fine?
+            new_node = ep.graph.call_function(func, kwargs=node.kwargs)
+        for k, v in node.meta.items():
+            new_node.meta[k] = v
+
+        # Replace auto_functionalize(func, args) with just func(args)
+        node.replace_all_uses_with(new_node)
+
+        mutable_args_names = get_mutable_arg_names(new_node.target)
+
+        # Update the users of the auto_functionalized node (the getitem nodes)
+        for user in list(new_node.users.keys()):
+            assert user.target == operator.getitem
+            # A getitem corresponding to a mutated input: replace all of its
+            # uses with the original input
+            if user.args[1] >= len(func._schema.returns):
+                assert user.args[1] <= len(func._schema.returns) + len(
+                    mutable_args_names
+                )
+                # If the result of the getitem was used in an output node, the
+                # output spec is updated with the correct name via the replace hook
+                adjusted_index = user.args[1] - len(func._schema.returns)
+                original_arg = original_kwargs[mutable_args_names[adjusted_index]]
+                # This is a little fragile/implementation dependent, but the
+                # order of the mutable args is the same as the order of the
+                # getitem calls following the HOP.
+                user.replace_all_uses_with(original_arg)
+
+        if len(func._schema.returns) == 1:
+            # If the function has 1 return then it will just directly return
+            # the result -- we don't need a getitem. So we can replace all the
+            # getitem(auto_functionalized, 0) calls with just the node itself.
+            for user in list(new_node.users.keys()):
+                if user.args[1] == 0:
+                    user.replace_all_uses_with(new_node)
+
+        new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
+        ep.graph.erase_node(node)
+
+    ep.graph.eliminate_dead_code()
+
+
 def unsafe_remove_auto_functionalized_pass(
     ep: ExportedProgram,
 ) -> ExportedProgram:
@@ -31,63 +82,7 @@ def unsafe_remove_auto_functionalized_pass(
         if node.op == "call_function" and node.target is auto_functionalized:
             auto_functionalize_nodes.append(node)
 
-    # Update every use of the HOP
-    for node in reversed(auto_functionalize_nodes):
-        func = node.args[0]
-        original_kwargs = node.kwargs
-        assert isinstance(func, torch._ops.OpOverload)
-
+    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
+        _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes)
-        with ep.graph.inserting_before(node):
-            # This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
-            new_node = ep.graph.call_function(func, kwargs=node.kwargs)
-        for k, v in node.meta.items():
-            new_node.meta[k] = v
-
-        # Replace auto_functionalize(func, args) with just func(args)
-        node.replace_all_uses_with(new_node)
-
-        mutable_args_names = get_mutable_arg_names(new_node.target)
-        output_specs = ep.graph_signature.output_specs
-
-        # update the users of the auto_func node (the getitem nodes)
-        for user in list(new_node.users.keys()):
-            assert user.target == operator.getitem
-            # getitem corresponding to a mutated input, just replace all uses with the original input
-            if user.args[1] >= len(func._schema.returns):
-                assert user.args[1] <= len(func._schema.returns) + len(
-                    mutable_args_names
-                )
-                # If the result of getitem was used in an output node, update the output spec with the correct name
-                adusted_index = user.args[1] - len(func._schema.returns)
-                original_arg = original_kwargs[mutable_args_names[adusted_index]]
-                for spec in output_specs:
-                    if spec.arg.name == user.name:
-                        spec.arg.name = original_arg.name  # pyre-ignore
-                        break
-                # This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
-                # of the getitem calls following the HOP.
-                user.replace_all_uses_with(
-                    original_kwargs[mutable_args_names[adusted_index]]
-                )
-
-        if len(func._schema.returns) == 1:
-            # If the function has 1 return then it will just directly return the
-            # result -- we don't need a getitem. So we can replace all the
-            # getitem(auto_functionalized, 0) with just the note itself.
-            for user in list(new_node.users.keys()):
-                if user.args[1] == 0:
-                    user.replace_all_uses_with(new_node)
-                    # Same case as above, update the output spec if getitem result used in an output node
-                    for spec in output_specs:
-                        if spec.arg.name == user.name:
-                            spec.arg.name = new_node.name
-                            break
-
-        new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
-        ep.graph.erase_node(node)
-
-    ep.graph.eliminate_dead_code()
 
     return ep
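
The refactor above is worth a note: the hand-written output_spec patching (the deleted loops over output_specs) is replaced by installing the signature's replace hook around the graph rewrite, so every replace_all_uses_with automatically renames the matching entries in the graph signature. A hedged sketch of that same pattern on a trivial exported program — the module and the redundant add-for-add swap are contrived purely to show the renaming, and _set_replace_hook is the internal API this commit itself uses:

import torch
from torch.export import export


class Add(torch.nn.Module):
    def forward(self, x):
        return x + 1


ep = export(Add(), (torch.randn(2),))

# Any replace_all_uses_with performed under this hook also renames the
# replaced node in the signature's input/output specs.
with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
    for node in list(ep.graph.nodes):
        if node.op == "call_function" and node.target is torch.ops.aten.add.Tensor:
            with ep.graph.inserting_before(node):
                new_node = ep.graph.call_function(torch.ops.aten.add.Tensor, node.args)
            new_node.meta.update(node.meta)
            node.replace_all_uses_with(new_node)
            ep.graph.erase_node(node)

# The output spec now names the replacement node rather than the erased one.
assert ep.graph_signature.output_specs[0].arg.name == new_node.name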

torch/export/_remove_effect_tokens_pass.py

@@ -7,48 +7,24 @@ from .exported_program import ExportedProgram
 from .graph_signature import InputKind, InputSpec, OutputKind, OutputSpec, TokenArgument
 
 
-def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
-    """
-    Removes the existance of tokens from the exported program, including:
-    - Removes the input and output tokens
-    - Replaces with_effects(token, func, args) with just func(args)
-
-    This function does an inplace modification on the given ExportedProgram.
-    """
-    num_tokens: int = 0
-    input_token_names: List[str] = []
-    new_input_specs: List[InputSpec] = []
-    for inp in ep.graph_signature.input_specs:
-        if inp.kind == InputKind.TOKEN:
-            num_tokens += 1
-            assert isinstance(inp.arg, TokenArgument)
-            input_token_names.append(inp.arg.name)
-        else:
-            new_input_specs.append(inp)
-
-    num_out_tokens: int = 0
-    new_output_specs: List[str] = []
-    output_token_names: List[OutputSpec] = []
-    for out in ep.graph_signature.output_specs:
-        if out.kind == OutputKind.TOKEN:
-            num_out_tokens += 1
-            output_token_names.append(out.arg.name)
-        else:
-            new_output_specs.append(out)
-
-    assert num_tokens == num_out_tokens
-
+def _remove_effect_tokens_from_graph_helper(
+    ep, num_tokens, input_token_names, output_token_names
+):
     output_node = None
     with_effect_nodes: List[torch.fx.Node] = []
-    for node in ep.graph.nodes:
-        if node.op == "output":
-            output_node = node
-            break
-        if not (node.op == "call_function" and node.target is with_effects):
-            continue
-        with_effect_nodes.append(node)
+    for module in ep.graph_module.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+
+        for node in ep.graph.nodes:
+            if node.op == "output":
+                output_node = node
+                break
+
+            if not (node.op == "call_function" and node.target is with_effects):
+                continue
+
+            with_effect_nodes.append(node)
 
     # Remove tokens from outputs
     assert output_node is not None
@@ -112,9 +88,47 @@
         assert inp_token.name in input_token_names
         ep.graph.erase_node(inp_token)
 
     ep.graph.eliminate_dead_code()
 
+
+def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
+    """
+    Removes the existence of tokens from the exported program, including:
+    - Removes the input and output tokens
+    - Replaces with_effects(token, func, args) with just func(args)
+
+    This function does an inplace modification on the given ExportedProgram.
+    """
+    num_tokens: int = 0
+    input_token_names: List[str] = []
+    new_input_specs: List[InputSpec] = []
+    for inp in ep.graph_signature.input_specs:
+        if inp.kind == InputKind.TOKEN:
+            num_tokens += 1
+            assert isinstance(inp.arg, TokenArgument)
+            input_token_names.append(inp.arg.name)
+        else:
+            new_input_specs.append(inp)
+
+    num_out_tokens: int = 0
+    new_output_specs: List[OutputSpec] = []
+    output_token_names: List[str] = []
+    for out in ep.graph_signature.output_specs:
+        if out.kind == OutputKind.TOKEN:
+            num_out_tokens += 1
+            output_token_names.append(out.arg.name)
+        else:
+            new_output_specs.append(out)
+
     # Update graph signature
     ep.graph_signature.input_specs = new_input_specs
     ep.graph_signature.output_specs = new_output_specs
-
-    ep.graph.eliminate_dead_code()
 
+    assert num_tokens == num_out_tokens
+
+    with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
+        _remove_effect_tokens_from_graph_helper(
+            ep, num_tokens, input_token_names, output_token_names
+        )
+
     return ep
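
A closing note on the invariant the asserts encode: each effectful call is exported as with_effects(token, op, args) returning a new token alongside the op's results, so every token appears exactly once as a TOKEN input spec and once as a TOKEN output spec — hence the num_tokens == num_out_tokens assertion, and hence both sides must be stripped as matched pairs. A small check along those lines; the helper name is made up for illustration and is not part of this pass:

from torch.export import ExportedProgram
from torch.export.graph_signature import InputKind, OutputKind


def assert_tokenless(ep: ExportedProgram) -> None:
    # After _remove_effect_tokens runs, no TOKEN specs should remain on
    # either side of the graph signature.
    sig = ep.graph_signature
    assert not any(s.kind == InputKind.TOKEN for s in sig.input_specs)
    assert not any(s.kind == OutputKind.TOKEN for s in sig.output_specs)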