mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Fix handling output in remove_effect_tokens_pass (#122357)
Added handling for updating the output_spec in the graph signature if the the result of a with_effects call is an output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122357 Approved by: https://github.com/zhxchen17
This commit is contained in:
parent
09eb07bee8
commit
fb57d1699b
|
|
@ -34,19 +34,25 @@ from torch._export.utils import (
|
|||
sequential_split,
|
||||
)
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
|
||||
from torch.export import export
|
||||
from torch.export._remove_auto_functionalized_pass import (
|
||||
unsafe_remove_auto_functionalized_pass,
|
||||
)
|
||||
from torch.export._remove_effect_tokens_pass import _remove_effect_tokens
|
||||
from torch.fx.passes.infra.partitioner import Partition
|
||||
from torch.fx.passes.operator_support import OperatorSupport
|
||||
from torch.library import impl, _scoped_library
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
find_library_location,
|
||||
run_tests,
|
||||
skipIfTorchDynamo,
|
||||
TestCase,
|
||||
skipIfTorchDynamo,
|
||||
IS_FBCODE,
|
||||
IS_MACOS,
|
||||
IS_SANDCASTLE,
|
||||
IS_WINDOWS,
|
||||
)
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
|
@ -194,6 +200,18 @@ class TestPasses(TestCase):
|
|||
self.SEQUENTIAL_SPLIT_INLINE_TESTS = _sequential_split_inline_tests()
|
||||
self.SET_GRAD_ENABLED_TESTS = _set_grad_enabled_tests()
|
||||
|
||||
if IS_SANDCASTLE or IS_FBCODE:
|
||||
torch.ops.load_library(
|
||||
"//caffe2/test/cpp/jit:test_custom_class_registrations"
|
||||
)
|
||||
elif IS_MACOS:
|
||||
raise unittest.SkipTest("non-portable load_library call used in test")
|
||||
else:
|
||||
lib_file_path = find_library_location('libtorchbind_test.so')
|
||||
if IS_WINDOWS:
|
||||
lib_file_path = find_library_location('torchbind_test.dll')
|
||||
torch.ops.load_library(str(lib_file_path))
|
||||
|
||||
def tearDown(self):
|
||||
self.SEQUENTIAL_SPLIT_INLINE_TESTS.clear()
|
||||
self.SET_GRAD_ENABLED_TESTS.clear()
|
||||
|
|
@ -354,6 +372,34 @@ class TestPasses(TestCase):
|
|||
if torch.Tag.core in op_overload.tags and is_view_op(op_overload._schema):
|
||||
self.assertIsNotNone(get_view_copy_of_view_op(op_overload._schema))
|
||||
|
||||
def test_custom_obj_tuple_out(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
|
||||
|
||||
def forward(self, x):
|
||||
a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
|
||||
y = a[0] + a[1]
|
||||
b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
|
||||
return b
|
||||
|
||||
m = MyModule()
|
||||
inputs = (torch.ones(2, 3),)
|
||||
with enable_torchbind_tracing():
|
||||
ep = torch.export.export(m, inputs, strict=False)
|
||||
|
||||
inp = torch.randn(2, 3)
|
||||
orig_res = m(inp)
|
||||
ep_res = ep.module()(inp)
|
||||
|
||||
without_token_ep = _remove_effect_tokens(ep)
|
||||
without_token_ep.verifier().check(without_token_ep)
|
||||
without_token_res = without_token_ep.module()(inp)
|
||||
|
||||
self.assertTrue(torch.allclose(orig_res, ep_res))
|
||||
self.assertTrue(torch.allclose(orig_res, without_token_res))
|
||||
|
||||
def test_runtime_assert_inline_constraints_for_item(self) -> None:
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,57 @@ from torch._higher_order_ops.auto_functionalize import (
|
|||
from torch.export import ExportedProgram
|
||||
|
||||
|
||||
def _remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes):
|
||||
# Update every use of the HOP
|
||||
for node in reversed(auto_functionalize_nodes):
|
||||
func = node.args[0]
|
||||
original_kwargs = node.kwargs
|
||||
assert isinstance(func, torch._ops.OpOverload)
|
||||
|
||||
with ep.graph.inserting_before(node):
|
||||
# This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
|
||||
new_node = ep.graph.call_function(func, kwargs=node.kwargs)
|
||||
for k, v in node.meta.items():
|
||||
new_node.meta[k] = v
|
||||
|
||||
# Replace auto_functionalize(func, args) with just func(args)
|
||||
node.replace_all_uses_with(new_node)
|
||||
|
||||
mutable_args_names = get_mutable_arg_names(new_node.target)
|
||||
|
||||
# update the users of the auto_func node (the getitem nodes)
|
||||
for user in list(new_node.users.keys()):
|
||||
assert user.target == operator.getitem
|
||||
# getitem corresponding to a mutated input, just replace all uses with the original input
|
||||
if user.args[1] >= len(func._schema.returns):
|
||||
assert user.args[1] <= len(func._schema.returns) + len(
|
||||
mutable_args_names
|
||||
)
|
||||
|
||||
# If the result of getitem was used in an output node, update the output spec with the correct name
|
||||
adusted_index = user.args[1] - len(func._schema.returns)
|
||||
original_arg = original_kwargs[mutable_args_names[adusted_index]]
|
||||
|
||||
# This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
|
||||
# of the getitem calls following the HOP.
|
||||
user.replace_all_uses_with(
|
||||
original_kwargs[mutable_args_names[adusted_index]]
|
||||
)
|
||||
|
||||
if len(func._schema.returns) == 1:
|
||||
# If the function has 1 return then it will just directly return the
|
||||
# result -- we don't need a getitem. So we can replace all the
|
||||
# getitem(auto_functionalized, 0) with just the note itself.
|
||||
for user in list(new_node.users.keys()):
|
||||
if user.args[1] == 0:
|
||||
user.replace_all_uses_with(new_node)
|
||||
|
||||
new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
|
||||
ep.graph.erase_node(node)
|
||||
|
||||
ep.graph.eliminate_dead_code()
|
||||
|
||||
|
||||
def unsafe_remove_auto_functionalized_pass(
|
||||
ep: ExportedProgram,
|
||||
) -> ExportedProgram:
|
||||
|
|
@ -31,63 +82,7 @@ def unsafe_remove_auto_functionalized_pass(
|
|||
if node.op == "call_function" and node.target is auto_functionalized:
|
||||
auto_functionalize_nodes.append(node)
|
||||
|
||||
# Update every use of the HOP
|
||||
for node in reversed(auto_functionalize_nodes):
|
||||
func = node.args[0]
|
||||
original_kwargs = node.kwargs
|
||||
assert isinstance(func, torch._ops.OpOverload)
|
||||
with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
|
||||
_remove_auto_functionalization_from_graph_helper(ep, auto_functionalize_nodes)
|
||||
|
||||
with ep.graph.inserting_before(node):
|
||||
# This makes the call_function refer to every arg as a kwarg, this is weird but probably fine?
|
||||
new_node = ep.graph.call_function(func, kwargs=node.kwargs)
|
||||
for k, v in node.meta.items():
|
||||
new_node.meta[k] = v
|
||||
|
||||
# Replace auto_functionalize(func, args) with just func(args)
|
||||
node.replace_all_uses_with(new_node)
|
||||
|
||||
mutable_args_names = get_mutable_arg_names(new_node.target)
|
||||
output_specs = ep.graph_signature.output_specs
|
||||
|
||||
# update the users of the auto_func node (the getitem nodes)
|
||||
for user in list(new_node.users.keys()):
|
||||
assert user.target == operator.getitem
|
||||
# getitem corresponding to a mutated input, just replace all uses with the original input
|
||||
if user.args[1] >= len(func._schema.returns):
|
||||
assert user.args[1] <= len(func._schema.returns) + len(
|
||||
mutable_args_names
|
||||
)
|
||||
|
||||
# If the result of getitem was used in an output node, update the output spec with the correct name
|
||||
adusted_index = user.args[1] - len(func._schema.returns)
|
||||
original_arg = original_kwargs[mutable_args_names[adusted_index]]
|
||||
for spec in output_specs:
|
||||
if spec.arg.name == user.name:
|
||||
spec.arg.name = original_arg.name # pyre-ignore
|
||||
break
|
||||
|
||||
# This is a little fragile/implementation dependent, but the order of the mutable args is the same as the order
|
||||
# of the getitem calls following the HOP.
|
||||
user.replace_all_uses_with(
|
||||
original_kwargs[mutable_args_names[adusted_index]]
|
||||
)
|
||||
|
||||
if len(func._schema.returns) == 1:
|
||||
# If the function has 1 return then it will just directly return the
|
||||
# result -- we don't need a getitem. So we can replace all the
|
||||
# getitem(auto_functionalized, 0) with just the note itself.
|
||||
for user in list(new_node.users.keys()):
|
||||
if user.args[1] == 0:
|
||||
user.replace_all_uses_with(new_node)
|
||||
|
||||
# Same case as above, update the output spec if getitem result used in an output node
|
||||
for spec in output_specs:
|
||||
if spec.arg.name == user.name:
|
||||
spec.arg.name = new_node.name
|
||||
break
|
||||
|
||||
new_node.meta["val"] = node.meta["val"][: len(func._schema.returns)]
|
||||
ep.graph.erase_node(node)
|
||||
|
||||
ep.graph.eliminate_dead_code()
|
||||
return ep
|
||||
|
|
|
|||
|
|
@ -7,48 +7,24 @@ from .exported_program import ExportedProgram
|
|||
from .graph_signature import InputKind, InputSpec, OutputKind, OutputSpec, TokenArgument
|
||||
|
||||
|
||||
def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
|
||||
"""
|
||||
Removes the existance of tokens from the exported program, including:
|
||||
- Removes the input and output tokens
|
||||
- Replaces with_effects(token, func, args) with just func(args)
|
||||
|
||||
This function does an inplace modification on the given ExportedProgram.
|
||||
"""
|
||||
num_tokens: int = 0
|
||||
input_token_names: List[str] = []
|
||||
new_input_specs: List[InputSpec] = []
|
||||
for inp in ep.graph_signature.input_specs:
|
||||
if inp.kind == InputKind.TOKEN:
|
||||
num_tokens += 1
|
||||
assert isinstance(inp.arg, TokenArgument)
|
||||
input_token_names.append(inp.arg.name)
|
||||
else:
|
||||
new_input_specs.append(inp)
|
||||
|
||||
num_out_tokens: int = 0
|
||||
new_output_specs: List[str] = []
|
||||
output_token_names: List[OutputSpec] = []
|
||||
for out in ep.graph_signature.output_specs:
|
||||
if out.kind == OutputKind.TOKEN:
|
||||
num_out_tokens += 1
|
||||
output_token_names.append(out.arg.name)
|
||||
else:
|
||||
new_output_specs.append(out)
|
||||
|
||||
assert num_tokens == num_out_tokens
|
||||
|
||||
def _remove_effect_tokens_from_graph_helper(
|
||||
ep, num_tokens, input_token_names, output_token_names
|
||||
):
|
||||
output_node = None
|
||||
with_effect_nodes: List[torch.fx.Node] = []
|
||||
for node in ep.graph.nodes:
|
||||
if node.op == "output":
|
||||
output_node = node
|
||||
break
|
||||
|
||||
if not (node.op == "call_function" and node.target is with_effects):
|
||||
for module in ep.graph_module.modules():
|
||||
if not isinstance(module, torch.fx.GraphModule):
|
||||
continue
|
||||
|
||||
with_effect_nodes.append(node)
|
||||
for node in ep.graph.nodes:
|
||||
if node.op == "output":
|
||||
output_node = node
|
||||
break
|
||||
|
||||
if not (node.op == "call_function" and node.target is with_effects):
|
||||
continue
|
||||
|
||||
with_effect_nodes.append(node)
|
||||
|
||||
# Remove tokens from outputs
|
||||
assert output_node is not None
|
||||
|
|
@ -112,9 +88,47 @@ def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
|
|||
assert inp_token.name in input_token_names
|
||||
ep.graph.erase_node(inp_token)
|
||||
|
||||
ep.graph.eliminate_dead_code()
|
||||
|
||||
|
||||
def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
|
||||
"""
|
||||
Removes the existance of tokens from the exported program, including:
|
||||
- Removes the input and output tokens
|
||||
- Replaces with_effects(token, func, args) with just func(args)
|
||||
|
||||
This function does an inplace modification on the given ExportedProgram.
|
||||
"""
|
||||
num_tokens: int = 0
|
||||
input_token_names: List[str] = []
|
||||
new_input_specs: List[InputSpec] = []
|
||||
for inp in ep.graph_signature.input_specs:
|
||||
if inp.kind == InputKind.TOKEN:
|
||||
num_tokens += 1
|
||||
assert isinstance(inp.arg, TokenArgument)
|
||||
input_token_names.append(inp.arg.name)
|
||||
else:
|
||||
new_input_specs.append(inp)
|
||||
|
||||
num_out_tokens: int = 0
|
||||
new_output_specs: List[OutputSpec] = []
|
||||
output_token_names: List[OutputSpec] = []
|
||||
for out in ep.graph_signature.output_specs:
|
||||
if out.kind == OutputKind.TOKEN:
|
||||
num_out_tokens += 1
|
||||
output_token_names.append(out.arg.name)
|
||||
else:
|
||||
new_output_specs.append(out)
|
||||
|
||||
# Update graph signature
|
||||
ep.graph_signature.input_specs = new_input_specs
|
||||
ep.graph_signature.output_specs = new_output_specs
|
||||
|
||||
ep.graph.eliminate_dead_code()
|
||||
assert num_tokens == num_out_tokens
|
||||
|
||||
with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()):
|
||||
_remove_effect_tokens_from_graph_helper(
|
||||
ep, num_tokens, input_token_names, output_token_names
|
||||
)
|
||||
|
||||
return ep
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user