[export] Upstream support of (tensor, tensor list) in op returns. (#111857)

Summary:
Upstreaming from internal to OSS.
Diff: D49710320

Test Plan: buck2 build mode/opt sigmoid/inference/test_gpu:package_gen

Differential Revision: D50577490

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111857
Approved by: https://github.com/SherlockNoMad
Author:    Zhengxu Chen
Date:      2023-10-25 21:38:07 +00:00
Committed: PyTorch MergeBot
Parent:    e5049648be
Commit:    f2a0bef35a

2 changed files with 140 additions and 38 deletions
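In schema terms, the title means an operator whose return type is "-> (Tensor, Tensor[])": a plain tensor together with a list of tensors. A minimal sketch of registering and calling such an op (the "_demo" namespace and op name are illustrative, not part of this patch; the actual test below registers a similar op in the existing "_export" namespace):

    import torch
    from torch.library import Library

    lib = Library("_demo", "FRAGMENT")  # hypothetical namespace for illustration
    lib.define("mixed_out(Tensor x, Tensor y) -> (Tensor, Tensor[])")

    def mixed_out(x, y):
        return y, [x, x + 1.0]

    lib.impl("mixed_out", mixed_out, "CPU")

    t, ts = torch.ops._demo.mixed_out(torch.ones(2), torch.zeros(2))
    # t is a Tensor; ts is a Python list of Tensors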

test/export/test_serialize.py

@@ -236,7 +236,7 @@ class TestDeserialize(TestCase):
         elif isinstance(val1, (list, tuple)) and isinstance(val2, (list, tuple)):
             # Or both are fake tensors lists with one element and with the
             # same shape/dtype
-            for v1, v2 in zip(val1, val2):
+            for v1, v2 in zip(pytree.tree_flatten(val1)[0], pytree.tree_flatten(val2)[0]):
                 self.assertEqual(v1.shape, v2.shape)
                 self.assertEqual(v1.dtype, v2.dtype)
         else:
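The flattening above matters because a compared value can now nest a tensor list inside a tuple return; pytree.tree_flatten reduces both sides to their tensor leaves so zip pairs them one-to-one regardless of nesting. A small sketch of the behavior being relied on:

    import torch
    import torch.utils._pytree as pytree

    # (Tensor, [Tensor]) flattens to two tensor leaves.
    val = (torch.ones(2, 3), [torch.zeros(4)])
    leaves, _spec = pytree.tree_flatten(val)
    assert [leaf.shape for leaf in leaves] == [torch.Size([2, 3]), torch.Size([4])]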
@@ -398,6 +398,24 @@ class TestDeserialize(TestCase):
         inputs = (torch.ones(3, 2, 2), torch.ones(2))
         self.check_graph(g, inputs, _check_meta=False)

+    def test_tensor_tensor_list(self):
+        from torch.library import Library
+        lib = Library("_export", "FRAGMENT")
+        lib.define("_test_tensor_tensor_list_output(Tensor x, Tensor y) -> (Tensor, Tensor[])")
+
+        def _test_tensor_tensor_list_output(x, y):
+            return y, [x]
+
+        lib.impl("_test_tensor_tensor_list_output", _test_tensor_tensor_list_output, "CPU")
+        lib.impl("_test_tensor_tensor_list_output", _test_tensor_tensor_list_output, "Meta")
+
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                a, b = torch.ops._export._test_tensor_tensor_list_output.default(x, y)
+                return a + b[0]
+
+        self.check_graph(M(), (torch.rand(3, 2), torch.rand(3, 2)))
+
     @parametrize(
         "name,case",
         get_filtered_export_db_tests(),
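check_graph is a helper internal to this test suite: it exports the module, serializes and deserializes the result, and compares the graphs. Outside the suite, a comparable round trip can be sketched with the public export APIs, assuming a build where torch.export.save/load are available:

    import io
    import torch

    # Assumes M and the custom op registered in the test above.
    ep = torch.export.export(M(), (torch.rand(3, 2), torch.rand(3, 2)))
    buffer = io.BytesIO()
    torch.export.save(ep, buffer)         # serialization path exercised by this change
    buffer.seek(0)
    reloaded = torch.export.load(buffer)  # deserialization path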

torch/_export/serde/serialize.py

@@ -11,7 +11,17 @@ import typing
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)

 import sympy
@@ -38,10 +48,8 @@ from .schema import ( # type: ignore[attr-defined]
     InputSpec,
     InputToBufferSpec,
     InputToParameterSpec,
-    UserInputSpec,
-    UserOutputSpec,
-    LossOutputSpec,
     Layout,
+    LossOutputSpec,
     MemoryFormat,
     ModuleCallEntry,
     ModuleCallSignature,
@@ -61,6 +69,8 @@ from .schema import ( # type: ignore[attr-defined]
     TensorMeta,
     TensorValue,
     TREESPEC_VERSION,
+    UserInputSpec,
+    UserOutputSpec,
 )
@@ -284,6 +294,16 @@ def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool:
     return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType)


+def _is_single_tensor_list_return(target: torch._ops.OpOverload) -> bool:
+    returns = target._schema.returns
+    if len(returns) != 1:
+        return False
+    return_type = returns[0].real_type
+    return isinstance(return_type, torch.ListType) and isinstance(
+        return_type.getElementType(), torch.TensorType
+    )
+
+
 @dataclass
 class GraphState:
     inputs: List[Argument] = field(default_factory=list)
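The new predicate inspects the overload's schema rather than the runtime value. The same check can be reproduced against an existing operator; aten::split.Tensor serves as a stand-in here because its schema ends in "-> Tensor[]":

    import torch

    op = torch.ops.aten.split.Tensor  # schema: "... -> Tensor[]"
    (ret,) = op._schema.returns       # exactly one return entry
    assert isinstance(ret.real_type, torch.ListType)
    assert isinstance(ret.real_type.getElementType(), torch.TensorType)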
@@ -800,38 +820,81 @@ class GraphModuleSerializer:
         meta_val = node.meta["val"]

+        def output_node_at_index(node, index):
+            for user in node.users:
+                assert user.target is operator.getitem, f"{user} is not a getitem node"
+                if index == user.args[1]:
+                    return user
+            return None
+
         # Check single value return
         if _is_single_tensor_return(node.target):
             # e.g "-> Tensor"
             return [Argument.create(as_tensor=self.serialize_tensor_output(node.name, meta_val))]
         elif len(returns) == 1 and isinstance(meta_val, torch.SymInt):
             # e.g "-> SymInt"
             return [Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))]
         elif len(returns) == 1 and isinstance(meta_val, torch.SymBool):
             # e.g "-> SymBool"
             return [Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))]
+        elif _is_single_tensor_list_return(node.target):
+            # e.g "-> Tensor[]"
+            tensor_args = []
+            for idx, meta in enumerate(meta_val):
+                user_node = output_node_at_index(node, idx)
+                name = (
+                    user_node.name
+                    if user_node is not None
+                    else f"{node.name}_unused_{idx}"
+                )
+                tensor_args.append(self.serialize_tensor_output(name, meta))
+            return [Argument.create(as_tensors=tensor_args)]

         # There are two possibilities at this point:
-        # - This operator returns a list of Tensors.
-        # - This operator returns multiple Tensors.
+        # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)"
+        # - This operator returns a tuple mixing Tensor and Tensor[], e.g. "-> (Tensor, Tensor[])"
         #
         # Either way, start by gathering a list of TensorArguments with the correct names.
         # For consistent naming with FX, consult the downstream `getitem` node and
         # make sure our outputs have the same name.
-        arg_list = self._handle_getitem_users(node)
-
+        output_arguments = []
+        for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)):
+            if meta is None:
+                assert isinstance(return_schema.real_type, torch.OptionalType)
+                output_arguments.append(Argument.create(as_none=()))
+            elif isinstance(meta, torch._subclasses.fake_tensor.FakeTensor):
+                assert isinstance(return_schema.real_type, torch.TensorType)
+                user_node = output_node_at_index(node, idx)
+                name = (
+                    user_node.name
+                    if user_node is not None
+                    else f"{node.name}_unused_{idx}"
+                )
+                output_arguments.append(
+                    Argument.create(as_tensor=self.serialize_tensor_output(name, meta))
+                )
+            elif isinstance(meta, list):
+                # for List[Tensor] return type
+                assert isinstance(
+                    return_schema.real_type, torch.ListType
+                ) and isinstance(
+                    return_schema.real_type.getElementType(), torch.TensorType
+                )
+                user_node = output_node_at_index(node, idx)
+                assert user_node is not None
-        # Then, pack the return value differently depending on what the return type is.
-        if len(returns) == 1:
-            return_type = returns[0].real_type
-            assert isinstance(return_type, torch.ListType) and isinstance(
-                return_type.getElementType(), torch.TensorType
-            ), "Only tensors and lists of tensors supported"
+
+                args = []
+                for i, m in enumerate(meta):
+                    if m is None:
+                        continue
+                    sub_user_node = output_node_at_index(user_node, i)
+                    assert sub_user_node is not None, f"No user found at index {i}"
-            return [Argument.create(as_tensors=arg_list)]
-        else:
-            assert all(
-                isinstance(ret.real_type, torch.TensorType) for ret in returns
-            ), f"Multiple returns can only have tensor returns, got: {[ret.real_type for ret in returns]}"
+
+                    args.append(self.serialize_tensor_output(sub_user_node.name, m))
+                output_arguments.append(Argument.create(as_tensors=args))

-        return [Argument.create(as_tensor=arg) for arg in arg_list]
+        return output_arguments

     def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]:
         meta_val = node.meta["val"]
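serialize_outputs names each serialized output after its downstream operator.getitem consumer so the names survive a round trip; output_node_at_index is what locates that consumer. The naming situation it handles shows up in any traced multi-output op; a sketch using plain symbolic_trace rather than the export pipeline:

    import operator
    import torch
    import torch.fx as fx

    class Demo(torch.nn.Module):
        def forward(self, x):
            parts = torch.split(x, 2)   # multi-output op
            return parts[0] + parts[1]  # indexing emits operator.getitem nodes

    graph = fx.symbolic_trace(Demo()).graph
    for node in graph.nodes:
        if node.target is operator.getitem:
            print(node.name, node.args[1])  # e.g. "getitem 0", "getitem_1 1"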
@@ -1364,23 +1427,10 @@ class GraphModuleDeserializer:
             self.deserialize_multiple_outputs(serialized_node, fx_node)

     def deserialize_multiple_outputs(self, serialized_node: Node, fx_node: torch.fx.Node) -> None:
-        # Convert multiple return types to FX format.
-        # In FX, each node only returns one value. So in order to represent
-        # multiple return values, we have to emit a `getitem` node for each
-        # return value.
-        # This performs the inverse mapping of the `serialize_outputs` call in
-        # serialization, see [NOTE: Multiple outputs]
-        output_names = []
-        if len(serialized_node.outputs) == 1:
-            assert isinstance(serialized_node.outputs[0].value, list)
-            assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
-            output_names = [arg.name for arg in serialized_node.outputs[0].as_tensors]
-        else:
-            for output in serialized_node.outputs:
-                assert isinstance(output.value, TensorArgument)
-                output_names.append(output.as_tensor.name)
+        deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)

-        for idx, name in enumerate(output_names):
+        def generate_getitem(meta_val, fx_node: torch.fx.Node, arg: TensorArgument, idx: int):
+            name = arg.name
             individual_output = self.graph.create_node(
                 "call_function",
                 operator.getitem,
@@ -1388,12 +1438,46 @@
                 name=name,
             )
             self.sync_fx_node(name, individual_output)
+            meta_val.append(self.serialized_name_to_meta[name])
+
             # The derived `getitem` nodes should have the same stacktrace as the
             # original `fx_node`
-            individual_output.meta.update(self.deserialize_metadata(serialized_node.metadata))
+            individual_output.meta.update(deserialized_metadata)
+
+        def generate_getitems(meta_val, fx_node: torch.fx.Node, args):
+            for idx, arg in enumerate(args):
+                if isinstance(arg, Argument):
+                    arg = arg.value
+                if isinstance(arg, TensorArgument):
+                    generate_getitem(meta_val, fx_node, arg, idx)
+                elif isinstance(arg, (list, tuple)):
+                    list_output = self.graph.create_node(
+                        "call_function",
+                        operator.getitem,
+                        (fx_node, idx),
+                    )
+                    meta_val.append([])
+                    generate_getitems(meta_val[-1], list_output, arg)
+                    list_output.meta.update(deserialized_metadata)
+                    list_output.meta['val'] = meta_val[-1]
+                else:
+                    raise NotImplementedError(f"Unimplemented node output type: {arg}")
+
+        # Convert multiple return types to FX format.
+        # In FX, each node only returns one value. So in order to represent
+        # multiple return values, we have to emit a `getitem` node for each
+        # return value.
+        # This performs the inverse mapping of the `serialize_outputs` call in
+        # serialization, see [NOTE: Multiple outputs]
+        meta_val: List[Any] = []
+        if len(serialized_node.outputs) == 1:
+            assert isinstance(serialized_node.outputs[0].value, list)
+            assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
+            generate_getitems(meta_val, fx_node, serialized_node.outputs[0].as_tensors)
+        else:
+            generate_getitems(meta_val, fx_node, serialized_node.outputs)

         # also update the metaval for `fx_node` to be a list(meta)
-        fx_node.meta["val"] = tuple(self.serialized_name_to_meta[name] for name in output_names)
+        fx_node.meta["val"] = tuple(meta_val)
         self.serialized_name_to_node[fx_node.name] = fx_node

     def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
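Taken together, generate_getitem and generate_getitems invert the serializer: for a "-> (Tensor, Tensor[])" op they rebuild one top-level getitem per output plus nested getitems into the list, and meta_val mirrors that nesting. Roughly, with illustrative node names for a two-element list:

    # op_out = call_function[target=the_op](x, y)
    # a      = call_function[target=operator.getitem](op_out, 0)  # the Tensor
    # b      = call_function[target=operator.getitem](op_out, 1)  # the Tensor[]
    # b_0    = call_function[target=operator.getitem](b, 0)
    # b_1    = call_function[target=operator.getitem](b, 1)
    #
    # fx_node.meta["val"] == (fake_a, [fake_b0, fake_b1])
    # b.meta["val"]       == [fake_b0, fake_b1]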