mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
[export] Upstream support of (tensor, tensor list) in op returns. (#111857)
Summary: Upstreaming from internal to oss. Diff: D49710320

Test Plan: buck2 build mode/opt sigmoid/inference/test_gpu:package_gen

Differential Revision: D50577490

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111857
Approved by: https://github.com/SherlockNoMad
This commit is contained in:
parent e5049648be
commit f2a0bef35a
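What changes, in brief: GraphModuleSerializer.serialize_outputs previously handled either a lone "Tensor[]" return or a tuple of plain Tensors (see the removed assert "Multiple returns can only have tensor returns" further down), so an operator schema mixing the two could not be serialized. This commit adds a dedicated path for single tensor-list returns, a per-return dispatch for mixed returns, and the matching nested-getitem reconstruction on the deserialization side. A minimal sketch of the newly supported schema shape (library and op names here are illustrative, mirroring the new test below):

    from torch.library import Library

    lib = Library("my_ns", "FRAGMENT")  # hypothetical fragment library
    # One plain tensor plus a list of tensors in the same return tuple:
    lib.define("my_op(Tensor x, Tensor y) -> (Tensor, Tensor[])")

The first two hunks touch the export serialization test suite (the page omits file paths; judging by the TestDeserialize class and the check_graph helper, likely test/export/test_serialize.py).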
@@ -236,7 +236,7 @@ class TestDeserialize(TestCase):
         elif isinstance(val1, (list, tuple)) and isinstance(val2, (list, tuple)):
             # Or both are fake tensors lists with one element and with the
             # same shape/dtype
-            for v1, v2 in zip(val1, val2):
+            for v1, v2 in zip(pytree.tree_flatten(val1)[0], pytree.tree_flatten(val2)[0]):
                 self.assertEqual(v1.shape, v2.shape)
                 self.assertEqual(v1.dtype, v2.dtype)
         else:
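The harness tweak above replaces an element-wise zip with pytree flattening, so the shape/dtype comparison also covers nested values such as (Tensor, [Tensor]) produced by the new return type. A minimal sketch of the behavior relied on, assuming the test module imports torch.utils._pytree as pytree (which the hunk's usage suggests):

    import torch
    import torch.utils._pytree as pytree

    # tree_flatten returns (leaves, treespec); leaves come out in a stable
    # order regardless of nesting, so this value flattens to two tensors.
    leaves, _ = pytree.tree_flatten((torch.ones(2), [torch.zeros(3)]))
    assert [tuple(t.shape) for t in leaves] == [(2,), (3,)]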
@@ -398,6 +398,24 @@ class TestDeserialize(TestCase):
         inputs = (torch.ones(3, 2, 2), torch.ones(2))
         self.check_graph(g, inputs, _check_meta=False)
 
+    def test_tensor_tensor_list(self):
+        from torch.library import Library
+        lib = Library("_export", "FRAGMENT")
+        lib.define("_test_tensor_tensor_list_output(Tensor x, Tensor y) -> (Tensor, Tensor[])")
+
+        def _test_tensor_tensor_list_output(x, y):
+            return y, [x]
+
+        lib.impl("_test_tensor_tensor_list_output", _test_tensor_tensor_list_output, "CPU")
+        lib.impl("_test_tensor_tensor_list_output", _test_tensor_tensor_list_output, "Meta")
+
+        class M(torch.nn.Module):
+            def forward(self, x, y):
+                a, b = torch.ops._export._test_tensor_tensor_list_output.default(x, y)
+                return a + b[0]
+
+        self.check_graph(M(), (torch.rand(3, 2), torch.rand(3, 2)))
+
     @parametrize(
         "name,case",
         get_filtered_export_db_tests(),
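The new test registers a throwaway operator in the "_export" fragment namespace whose schema returns a (Tensor, Tensor[]) tuple, then round-trips a module consuming both outputs through check_graph. In eager mode the registered op behaves as sketched here (a usage sketch, not part of the diff):

    import torch

    # Assumes the library registration from test_tensor_tensor_list has run.
    x, y = torch.rand(3, 2), torch.rand(3, 2)
    a, b = torch.ops._export._test_tensor_tensor_list_output.default(x, y)
    assert torch.equal(a, y)                             # first return: a plain tensor
    assert isinstance(b, list) and torch.equal(b[0], x)  # second return: a list of tensors

The remaining hunks are from the serializer module itself (again no path on this page; from the .schema imports and the GraphModuleSerializer/GraphModuleDeserializer classes, likely torch/_export/serde/serialize.py).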
@@ -11,7 +11,17 @@ import typing
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import sympy
@@ -38,10 +48,8 @@ from .schema import (  # type: ignore[attr-defined]
     InputSpec,
     InputToBufferSpec,
     InputToParameterSpec,
-    UserInputSpec,
-    UserOutputSpec,
-    LossOutputSpec,
     Layout,
+    LossOutputSpec,
     MemoryFormat,
     ModuleCallEntry,
     ModuleCallSignature,
@@ -61,6 +69,8 @@ from .schema import (  # type: ignore[attr-defined]
     TensorMeta,
     TensorValue,
     TREESPEC_VERSION,
+    UserInputSpec,
+    UserOutputSpec,
 )
 
 
@@ -284,6 +294,16 @@ def _is_single_tensor_return(target: torch._ops.OpOverload) -> bool:
     return len(returns) == 1 and isinstance(returns[0].real_type, torch.TensorType)
 
 
+def _is_single_tensor_list_return(target: torch._ops.OpOverload) -> bool:
+    returns = target._schema.returns
+    if len(returns) != 1:
+        return False
+    return_type = returns[0].real_type
+    return isinstance(return_type, torch.ListType) and isinstance(
+        return_type.getElementType(), torch.TensorType
+    )
+
+
 @dataclass
 class GraphState:
     inputs: List[Argument] = field(default_factory=list)
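The new helper mirrors the existing _is_single_tensor_return: it matches schemas with exactly one return whose type is a list of tensors. A quick sanity check, assuming this commit is applied (the ATen ops chosen here are illustrative):

    import torch
    from torch._export.serde.serialize import _is_single_tensor_list_return

    # "aten::split.Tensor(...) -> Tensor(a)[]" has a single tensor-list return.
    assert _is_single_tensor_list_return(torch.ops.aten.split.Tensor)
    # "aten::topk(...) -> (Tensor, Tensor)" has two returns, so it does not match.
    assert not _is_single_tensor_list_return(torch.ops.aten.topk.default)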
@@ -800,38 +820,81 @@ class GraphModuleSerializer:
 
         meta_val = node.meta["val"]
 
+        def output_node_at_index(node, index):
+            for user in node.users:
+                assert user.target is operator.getitem, f"{user} is not a getitem node"
+                if index == user.args[1]:
+                    return user
+            return None
+
         # Check single value return
         if _is_single_tensor_return(node.target):
+            # e.g "-> Tensor"
             return [Argument.create(as_tensor=self.serialize_tensor_output(node.name, meta_val))]
         elif len(returns) == 1 and isinstance(meta_val, torch.SymInt):
+            # e.g "-> SymInt"
             return [Argument.create(as_sym_int=self.serialize_sym_int_output(node.name, meta_val))]
         elif len(returns) == 1 and isinstance(meta_val, torch.SymBool):
+            # e.g "-> SymBool"
             return [Argument.create(as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val))]
+        elif _is_single_tensor_list_return(node.target):
+            # e.g "-> Tensor[]"
+            tensor_args = []
+            for idx, meta in enumerate(meta_val):
+                user_node = output_node_at_index(node, idx)
+                name = (
+                    user_node.name
+                    if user_node is not None
+                    else f"{node.name}_unused_{idx}"
+                )
+                tensor_args.append(self.serialize_tensor_output(name, meta))
+            return [Argument.create(as_tensors=tensor_args)]
 
         # There are a two possibilities at this point:
-        # - This operator returns a list of Tensors.
-        # - This operator returns multiple Tensors.
+        # - This operator returns a tuple of Tensors, e.g. "-> (Tensor, Tensor)"
+        # - This operator returns a tuple of mixed of Tensor and Tensors, e.g. "-> (Tensor, Tensor[])"
         #
         # Either way, start by gathering a list of TensorArguments with the correct names.
         # For consistent naming with FX, consult the downstream `getitem` node and
         # make sure our outputs have the same name.
 
-        arg_list = self._handle_getitem_users(node)
-
-        # Then, pack the return value differently depending on what the return type is.
-        if len(returns) == 1:
-            return_type = returns[0].real_type
-            assert isinstance(return_type, torch.ListType) and isinstance(
-                return_type.getElementType(), torch.TensorType
-            ), "Only tensors and lists of tensors supported"
-
-            return [Argument.create(as_tensors=arg_list)]
-        else:
-            assert all(
-                isinstance(ret.real_type, torch.TensorType) for ret in returns
-            ), f"Multiple returns can only have tensor returns, got: {[ret.real_type for ret in returns]}"
-
-            return [Argument.create(as_tensor=arg) for arg in arg_list]
+        output_arguments = []
+        for idx, (meta, return_schema) in enumerate(zip(meta_val, returns)):
+            if meta is None:
+                assert isinstance(return_schema.real_type, torch.OptionalType)
+                output_arguments.append(Argument.create(as_none=()))
+            elif isinstance(meta, torch._subclasses.fake_tensor.FakeTensor):
+                assert isinstance(return_schema.real_type, torch.TensorType)
+                user_node = output_node_at_index(node, idx)
+                name = (
+                    user_node.name
+                    if user_node is not None
+                    else f"{node.name}_unused_{idx}"
+                )
+                output_arguments.append(
+                    Argument.create(as_tensor=self.serialize_tensor_output(name, meta))
+                )
+            elif isinstance(meta, list):
+                # for List[Tensor] return type
+                assert isinstance(
+                    return_schema.real_type, torch.ListType
+                ) and isinstance(
+                    return_schema.real_type.getElementType(), torch.TensorType
+                )
+                user_node = output_node_at_index(node, idx)
+                assert user_node is not None
+
+                args = []
+                for i, m in enumerate(meta):
+                    if m is None:
+                        continue
+                    sub_user_node = output_node_at_index(user_node, i)
+                    assert sub_user_node is not None, f"No user found at index {i}"
+
+                    args.append(self.serialize_tensor_output(sub_user_node.name, m))
+                output_arguments.append(Argument.create(as_tensors=args))
+
+        return output_arguments
 
     def _handle_getitem_users(self, node: torch.fx.Node) -> List[TensorArgument]:
         meta_val = node.meta["val"]
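serialize_outputs now distinguishes four shapes: a single Tensor, a single SymInt/SymBool, a single Tensor[], and a mixed tuple. For the list cases it consults the downstream getitem users to recover stable output names, adding a second getitem level for list elements. A sketch of the graph shape walked for "-> (Tensor, Tensor[])" and of the serialized result (all names illustrative):

    # %out = call_function[target=my_op](%x, %y)
    # %a   = call_function[target=operator.getitem](%out, 0)   # the Tensor
    # %b   = call_function[target=operator.getitem](%out, 1)   # the Tensor[]
    # %b0  = call_function[target=operator.getitem](%b, 0)     # list element 0
    #
    # serialize_outputs would then return roughly:
    # [Argument(as_tensor=TensorArgument(name="a")),
    #  Argument(as_tensors=[TensorArgument(name="b0")])]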
@@ -1364,23 +1427,10 @@ class GraphModuleDeserializer:
             self.deserialize_multiple_outputs(serialized_node, fx_node)
 
     def deserialize_multiple_outputs(self, serialized_node: Node, fx_node: torch.fx.Node) -> None:
-        # Convert multiple return types to FX format.
-        # In FX, each node only returns one value. So in order to represent
-        # multiple return values, we have to emit a `getitem` node for each
-        # return value.
-        # This performs the inverse mapping of the `serialize_outputs` call in
-        # serialization, see [NOTE: Multiple outputs]
-        output_names = []
-        if len(serialized_node.outputs) == 1:
-            assert isinstance(serialized_node.outputs[0].value, list)
-            assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
-            output_names = [arg.name for arg in serialized_node.outputs[0].as_tensors]
-        else:
-            for output in serialized_node.outputs:
-                assert isinstance(output.value, TensorArgument)
-                output_names.append(output.as_tensor.name)
+        deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
 
-        for idx, name in enumerate(output_names):
+        def generate_getitem(meta_val, fx_node: torch.fx.Node, arg: TensorArgument, idx: int):
+            name = arg.name
             individual_output = self.graph.create_node(
                 "call_function",
                 operator.getitem,
@@ -1388,12 +1438,46 @@ class GraphModuleDeserializer:
                 name=name,
             )
             self.sync_fx_node(name, individual_output)
+            meta_val.append(self.serialized_name_to_meta[name])
             # The derived `getitem` nodes should have the same stacktrace as the
             # original `fx_node`
-            individual_output.meta.update(self.deserialize_metadata(serialized_node.metadata))
+            individual_output.meta.update(deserialized_metadata)
 
+        def generate_getitems(meta_val, fx_node: torch.fx.Node, args):
+            for idx, arg in enumerate(args):
+                if isinstance(arg, Argument):
+                    arg = arg.value
+                if isinstance(arg, TensorArgument):
+                    generate_getitem(meta_val, fx_node, arg, idx)
+                elif isinstance(arg, (list, tuple)):
+                    list_output = self.graph.create_node(
+                        "call_function",
+                        operator.getitem,
+                        (fx_node, idx),
+                    )
+                    meta_val.append([])
+                    generate_getitems(meta_val[-1], list_output, arg)
+                    list_output.meta.update(deserialized_metadata)
+                    list_output.meta['val'] = meta_val[-1]
+                else:
+                    raise NotImplementedError(f"Unimplemented node output type: {arg}")
+
+        # Convert multiple return types to FX format.
+        # In FX, each node only returns one value. So in order to represent
+        # multiple return values, we have to emit a `getitem` node for each
+        # return value.
+        # This performs the inverse mapping of the `serialize_outputs` call in
+        # serialization, see [NOTE: Multiple outputs]
+        meta_val: List[Any] = []
+        if len(serialized_node.outputs) == 1:
+            assert isinstance(serialized_node.outputs[0].value, list)
+            assert isinstance(serialized_node.outputs[0].value[0], TensorArgument)
+            generate_getitems(meta_val, fx_node, serialized_node.outputs[0].as_tensors)
+        else:
+            generate_getitems(meta_val, fx_node, serialized_node.outputs)
+
         # also update the metaval for `fx_node` to be a list(meta)
-        fx_node.meta["val"] = tuple(self.serialized_name_to_meta[name] for name in output_names)
+        fx_node.meta["val"] = tuple(meta_val)
         self.serialized_name_to_node[fx_node.name] = fx_node
 
     def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
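On the way back, generate_getitems recurses over the serialized outputs: a TensorArgument becomes one getitem node, while a nested list first gets a getitem node for the list itself and then one getitem per element, with meta_val mirroring the nesting. A sketch of the nodes emitted for serialized outputs [as_tensor "a", as_tensors ["b0", "b1"]] (names illustrative):

    # a  = operator.getitem(fx_node, 0)
    # b  = operator.getitem(fx_node, 1)   # list_output; its meta["val"] is [meta_b0, meta_b1]
    # b0 = operator.getitem(b, 0)
    # b1 = operator.getitem(b, 1)
    #
    # and finally fx_node.meta["val"] == (meta_a, [meta_b0, meta_b1])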