mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: When a C++ custom op returns an uninitialized tensor, it will be marked as None in Python. For this scenario, the user should mark the possibly uninitialized return as Tensor? in the custom op schema. This diff adds `as_optional_tensor` type to export schema and the support for optional tensor in AOTI proxy executor. Test Plan: ``` buck2 run mode/dev-nosan caffe2/test/inductor:test_aot_inductor_custom_ops -- -r test_fn_with_optional_tensor_output ``` Differential Revision: D75262529 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154286 Approved by: https://github.com/desertfire
506 lines
14 KiB
Python
506 lines
14 KiB
Python
# NOTE: This is a placeholder for iterating on export serialization schema design.
|
|
# Anything is subject to change and no guarantee is provided at this point.
|
|
|
|
from dataclasses import dataclass, field
|
|
from enum import IntEnum
|
|
from typing import Annotated, Optional
|
|
|
|
from torch._export.serde.union import _Union
|
|
|
|
|
|
# NOTE: Please update this value if any modifications are made to the schema
# NOTE(review): presumably a (major, minor) pair mirroring SchemaVersion —
# confirm against the serializer.
SCHEMA_VERSION = (8, 8)
# Version number recorded for serialized tree specs.
TREESPEC_VERSION = 1
|
|
|
|
|
|
# NOTE: If you updated the schema, please run `scripts/export/update_schema.py`
# to update the auto generated files.
#
# Integer codes for tensor element types (used by TensorMeta.dtype and
# Argument.as_scalar_type). The numbering is not contiguous: codes jump from
# BFLOAT16 (13) to UINT16 (28), so do not derive values from position.
class ScalarType(IntEnum):
    UNKNOWN = 0
    BYTE = 1
    CHAR = 2
    SHORT = 3
    INT = 4
    LONG = 5
    HALF = 6
    FLOAT = 7
    DOUBLE = 8
    COMPLEXHALF = 9
    COMPLEXFLOAT = 10
    COMPLEXDOUBLE = 11
    BOOL = 12
    BFLOAT16 = 13
    UINT16 = 28
    FLOAT8E4M3FN = 29
    FLOAT8E5M2 = 30
|
|
|
|
|
|
# Integer codes for tensor layouts (used by TensorMeta.layout and
# Argument.as_layout).
class Layout(IntEnum):
    Unknown = 0
    SparseCoo = 1
    SparseCsr = 2
    SparseCsc = 3
    SparseBsr = 4
    SparseBsc = 5
    _mkldnn = 6
    Strided = 7
|
|
|
|
|
|
# Integer codes for memory formats (used by Argument.as_memory_format).
class MemoryFormat(IntEnum):
    Unknown = 0
    ContiguousFormat = 1
    ChannelsLast = 2
    ChannelsLast3d = 3
    PreserveFormat = 4
|
|
|
|
|
|
# A device described by a type string and an optional integer index.
# `index` defaults to None when it is not given.
@dataclass
class Device:
    type: Annotated[str, 10]
    index: Annotated[Optional[int], 20] = None
|
|
|
|
|
|
# Union-style record (subclass of `_Union`) carrying a concrete hint value
# for a symbolic expression, with int, bool and float alternatives.
@dataclass(repr=False)
class SymExprHint(_Union):
    as_int: Annotated[int, 10]
    as_bool: Annotated[bool, 20]
    as_float: Annotated[float, 30]
|
|
|
|
|
|
# This is for storing the symbolic expressions behind symints/symfloats/symbools
# For example, we can get something like
# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
    expr_str: Annotated[str, 10]
    # Optional concrete value for the expression; None when no hint is stored.
    hint: Annotated[Optional[SymExprHint], 20] = None
|
|
|
|
|
|
# Union-style record (subclass of `_Union`): either a symbolic expression
# or a plain int.
@dataclass(repr=False)
class SymInt(_Union):
    as_expr: Annotated[SymExpr, 10]
    as_int: Annotated[int, 20]
|
|
|
|
|
|
# Union-style record (subclass of `_Union`): either a symbolic expression
# or a plain float.
@dataclass(repr=False)
class SymFloat(_Union):
    as_expr: Annotated[SymExpr, 10]
    as_float: Annotated[float, 20]
|
|
|
|
|
|
# Union-style record (subclass of `_Union`): either a symbolic expression
# or a plain bool.
@dataclass(repr=False)
class SymBool(_Union):
    as_expr: Annotated[SymExpr, 10]
    as_bool: Annotated[bool, 20]
|
|
|
|
|
|
# Metadata describing a tensor: element type, sizes, requires_grad flag,
# device, strides, storage offset and layout. Sizes, strides and the storage
# offset are SymInt, so each may be a concrete int or a symbolic expression.
@dataclass
class TensorMeta:
    dtype: Annotated[ScalarType, 10]
    sizes: Annotated[list[SymInt], 20]
    requires_grad: Annotated[bool, 30]
    device: Annotated[Device, 40]
    strides: Annotated[list[SymInt], 50]
    storage_offset: Annotated[SymInt, 60]
    layout: Annotated[Layout, 70]
|
|
|
|
|
|
# In most cases we will use the "as_name" field to store arguments which are
# SymInts.
# The "as_int" field is used in the case where we have a list containing a mix
# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
# to the "as_int" field.
@dataclass(repr=False)
class SymIntArgument(_Union):
    as_name: Annotated[str, 10]
    as_int: Annotated[int, 20]
|
|
|
|
|
|
# In most cases we will use the "as_name" field to store arguments which are
# SymFloats.
# The "as_float" field is used in the case where we have a list containing a mix
# of SymFloat and float (ex. [1.0, s0, ...]). We will serialize this type of list to
# be List[SymFloatArgument] and map the SymFloats to the "as_name" field, and floats
# to the "as_float" field.
@dataclass(repr=False)
class SymFloatArgument(_Union):
    as_name: Annotated[str, 10]
    as_float: Annotated[float, 20]
|
|
|
|
|
|
# In most cases we will use the "as_name" field to store arguments which are
# SymBools.
# The "as_bool" field is used in the case where we have a list containing a mix
# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
# be List[SymBoolArgument] and map the SymBools to the "as_name" field, and bools
# to the "as_bool" field.
@dataclass(repr=False)
class SymBoolArgument(_Union):
    as_name: Annotated[str, 10]
    as_bool: Annotated[bool, 20]
|
|
|
|
|
|
# Reference to a tensor by name.
@dataclass
class TensorArgument:
    name: Annotated[str, 10]
|
|
|
|
|
|
# Reference to a token by name.
@dataclass
class TokenArgument:
    name: Annotated[str, 10]
|
|
|
|
|
|
# This is used for storing the contents of a list which contains optional tensors
# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
# type List[OptionalTensorArgument], with tensor values serialized to the
# "as_tensor" field, and None values serialized to the "as_none" field.
#
# Field order is kept exactly as declared (as_tensor before as_none, although
# its annotation number is larger) because dataclass field order defines the
# positional constructor order.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
    as_tensor: Annotated[TensorArgument, 20]
    as_none: Annotated[bool, 10]
|
|
|
|
|
|
# A named reference to a nested graph.
@dataclass
class GraphArgument:
    name: Annotated[str, 10]
    # "Graph" is a quoted forward reference: the Graph class is defined
    # further down in this file.
    graph: Annotated["Graph", 20]
|
|
|
|
|
|
# Reference to a custom object by name, together with the fully qualified
# name of its class.
@dataclass
class CustomObjArgument:
    name: Annotated[str, 10]
    class_fqn: Annotated[str, 20]
|
|
|
|
|
|
# This is actually a union type
#
# Each field is one alternative for the value an argument can hold (None, a
# tensor, lists of tensors, scalars, symbolic values, enums, a nested graph,
# a custom object, an operator name, ...).
#
# Field order is kept exactly as declared because dataclass field order
# defines the positional constructor order. Note that the annotation numbers
# are not monotonically increasing (as_graph=200 precedes
# as_optional_tensors=190).
@dataclass(repr=False)
class Argument(_Union):
    as_none: Annotated[bool, 10]
    as_tensor: Annotated[TensorArgument, 20]
    as_tensors: Annotated[list[TensorArgument], 30]
    as_int: Annotated[int, 50]
    as_ints: Annotated[list[int], 70]
    as_float: Annotated[float, 80]
    as_floats: Annotated[list[float], 90]
    as_string: Annotated[str, 100]
    as_strings: Annotated[list[str], 101]
    as_sym_int: Annotated[SymIntArgument, 110]
    as_sym_ints: Annotated[list[SymIntArgument], 120]
    as_scalar_type: Annotated[ScalarType, 130]
    as_memory_format: Annotated[MemoryFormat, 140]
    as_layout: Annotated[Layout, 150]
    as_device: Annotated[Device, 160]
    as_bool: Annotated[bool, 170]
    as_bools: Annotated[list[bool], 180]
    as_sym_bool: Annotated[SymBoolArgument, 182]
    as_sym_bools: Annotated[list[SymBoolArgument], 184]
    as_graph: Annotated[GraphArgument, 200]
    as_optional_tensors: Annotated[list[OptionalTensorArgument], 190]
    as_custom_obj: Annotated[CustomObjArgument, 210]
    as_operator: Annotated[str, 220]
    as_sym_float: Annotated[SymFloatArgument, 230]
    as_sym_floats: Annotated[list[SymFloatArgument], 240]
    as_optional_tensor: Annotated[OptionalTensorArgument, 250]
|
|
|
|
|
|
# How an argument is passed: positionally or by keyword
# (used by NamedArgument.kind).
class ArgumentKind(IntEnum):
    UNKNOWN = 0
    POSITIONAL = 1
    KEYWORD = 2
|
|
|
|
|
|
# An argument value paired with its name and, optionally, how it is passed.
@dataclass
class NamedArgument:
    # Argument name from the operator schema
    name: Annotated[str, 10]
    arg: Annotated[Argument, 20]
    # None when the kind is not recorded.
    kind: Annotated[Optional[ArgumentKind], 30] = None
|
|
|
|
|
|
# A single graph node: a target string, named inputs, outputs and
# string-to-string metadata.
@dataclass
class Node:
    target: Annotated[str, 10]
    inputs: Annotated[list[NamedArgument], 20]
    outputs: Annotated[list[Argument], 30]
    metadata: Annotated[dict[str, str], 40]
    # NOTE(review): presumably the per-node analogue of
    # Graph.is_single_tensor_return for higher order ops — confirm.
    is_hop_single_tensor_return: Annotated[Optional[bool], 50] = None
|
|
|
|
|
|
# A graph: its inputs, outputs and nodes, plus name-keyed tables describing
# the tensor, symbolic and custom-object values it refers to.
@dataclass
class Graph:
    inputs: Annotated[list[Argument], 10]
    outputs: Annotated[list[Argument], 20]
    nodes: Annotated[list[Node], 30]
    tensor_values: Annotated[dict[str, TensorMeta], 40]
    sym_int_values: Annotated[dict[str, SymInt], 50]
    sym_bool_values: Annotated[dict[str, SymBool], 60]
    # This is for deserializing the submodule graphs from higher order ops
    # (ex. cond, map) where single tensor returns will just return a single
    # tensor, rather than following export schema and returning a singleton
    # list.
    is_single_tensor_return: Annotated[bool, 70] = False
    # default_factory gives every instance its own fresh dict.
    custom_obj_values: Annotated[dict[str, CustomObjArgument], 80] = field(
        default_factory=dict
    )
    sym_float_values: Annotated[dict[str, SymFloat], 90] = field(default_factory=dict)
|
|
|
|
|
|
# Input spec for an input supplied by the user.
@dataclass
class UserInputSpec:
    # Actually, only tensors and SymInts are allowed here
    arg: Annotated[Argument, 10]
|
|
|
|
|
|
# Union-style record (subclass of `_Union`) for a constant value: None, int,
# float, string or bool.
@dataclass(repr=False)
class ConstantValue(_Union):
    as_none: Annotated[bool, 10]
    as_int: Annotated[int, 20]
    as_float: Annotated[float, 30]
    as_string: Annotated[str, 40]
    as_bool: Annotated[bool, 50]
|
|
|
|
|
|
# Input spec binding an input name to a constant value.
@dataclass
class InputToConstantInputSpec:
    name: Annotated[str, 10]
    value: Annotated[ConstantValue, 20]
|
|
|
|
|
|
# Input spec associating a tensor argument with a parameter name.
@dataclass
class InputToParameterSpec:
    arg: Annotated[TensorArgument, 10]
    parameter_name: Annotated[str, 20]
|
|
|
|
|
|
# Input spec associating a tensor argument with a buffer name, plus a flag
# recording whether the buffer is persistent.
@dataclass
class InputToBufferSpec:
    arg: Annotated[TensorArgument, 10]
    buffer_name: Annotated[str, 20]
    persistent: Annotated[bool, 30]
|
|
|
|
|
|
# Input spec associating a tensor argument with a tensor-constant name.
@dataclass
class InputToTensorConstantSpec:
    arg: Annotated[TensorArgument, 10]
    tensor_constant_name: Annotated[str, 20]
|
|
|
|
|
|
# Input spec associating a custom-object argument with a custom-object name.
@dataclass
class InputToCustomObjSpec:
    arg: Annotated[CustomObjArgument, 10]
    custom_obj_name: Annotated[str, 20]
|
|
|
|
|
|
# Input spec for a token argument.
@dataclass
class InputTokenSpec:
    arg: Annotated[TokenArgument, 10]
|
|
|
|
|
|
# Union-style record (subclass of `_Union`) over every kind of graph input:
# user input, parameter, buffer, tensor constant, custom object, token or
# constant input.
#
# Field order is kept exactly as declared (token=70 precedes
# constant_input=60) because dataclass field order defines the positional
# constructor order.
@dataclass(repr=False)
class InputSpec(_Union):
    user_input: Annotated[UserInputSpec, 10]
    parameter: Annotated[InputToParameterSpec, 20]
    buffer: Annotated[InputToBufferSpec, 30]
    tensor_constant: Annotated[InputToTensorConstantSpec, 40]
    custom_obj: Annotated[InputToCustomObjSpec, 50]
    token: Annotated[InputTokenSpec, 70]
    constant_input: Annotated[InputToConstantInputSpec, 60]
|
|
|
|
|
|
# Output spec for an output returned to the user.
@dataclass
class UserOutputSpec:
    arg: Annotated[Argument, 10]
|
|
|
|
|
|
# Output spec for a loss output, held as a tensor argument.
@dataclass
class LossOutputSpec:
    arg: Annotated[TensorArgument, 10]
|
|
|
|
|
|
# Output spec associating a tensor argument with the name of a mutated buffer.
@dataclass
class BufferMutationSpec:
    arg: Annotated[TensorArgument, 10]
    buffer_name: Annotated[str, 20]
|
|
|
|
|
|
# Output spec associating a gradient tensor argument with a parameter name.
@dataclass
class GradientToParameterSpec:
    arg: Annotated[TensorArgument, 10]
    parameter_name: Annotated[str, 20]
|
|
|
|
|
|
# Output spec associating a gradient tensor argument with a user input name.
@dataclass
class GradientToUserInputSpec:
    arg: Annotated[TensorArgument, 10]
    user_input_name: Annotated[str, 20]
|
|
|
|
|
|
# Output spec associating a tensor argument with the name of a mutated
# user input.
@dataclass
class UserInputMutationSpec:
    arg: Annotated[TensorArgument, 10]
    user_input_name: Annotated[str, 20]
|
|
|
|
|
|
# Output spec for a token argument.
@dataclass
class OutputTokenSpec:
    arg: Annotated[TokenArgument, 10]
|
|
|
|
|
|
# Union-style record (subclass of `_Union`) over every kind of graph output:
# user output, loss output, buffer mutation, gradient to parameter, gradient
# to user input, user input mutation or token.
@dataclass(repr=False)
class OutputSpec(_Union):
    user_output: Annotated[UserOutputSpec, 10]
    loss_output: Annotated[LossOutputSpec, 20]
    buffer_mutation: Annotated[BufferMutationSpec, 30]
    gradient_to_parameter: Annotated[GradientToParameterSpec, 40]
    gradient_to_user_input: Annotated[GradientToUserInputSpec, 50]
    user_input_mutation: Annotated[UserInputMutationSpec, 60]
    token: Annotated[OutputTokenSpec, 70]
|
|
|
|
|
|
# The lists of input specs and output specs of a graph.
@dataclass
class GraphSignature:
    input_specs: Annotated[list[InputSpec], 10]
    output_specs: Annotated[list[OutputSpec], 20]
|
|
|
|
|
|
# A range with optional integer bounds. Both fields are required constructor
# arguments (no defaults) even though each may be None.
@dataclass
class RangeConstraint:
    min_val: Annotated[Optional[int], 10]
    max_val: Annotated[Optional[int], 20]
|
|
|
|
|
|
# Signature of a module call: its input and output arguments plus the
# string-encoded input/output tree specs.
@dataclass
class ModuleCallSignature:
    inputs: Annotated[list[Argument], 10]
    outputs: Annotated[list[Argument], 20]

    # These are serialized by calling pytree.treespec_dumps
    # And deserialized by calling pytree.treespec_loads
    in_spec: Annotated[str, 30]
    out_spec: Annotated[str, 40]

    # This field is used to prettify the graph placeholders
    # after we ser/der and retrace
    forward_arg_names: Annotated[Optional[list[str]], 50] = None
|
|
|
|
|
|
# One module call: the module's fully qualified name and, optionally, its
# call signature (None when no signature is stored).
@dataclass
class ModuleCallEntry:
    fqn: Annotated[str, 10]
    signature: Annotated[Optional[ModuleCallSignature], 30] = None
|
|
|
|
|
|
# Definition of a namedtuple type, given as its list of field names.
@dataclass
class NamedTupleDef:
    field_names: Annotated[list[str], 10]
|
|
|
|
|
|
# A graph together with its signature, module call graph and metadata.
#
# Field order is kept exactly as declared (the annotation numbers are not
# monotonically increasing) because dataclass field order defines the
# positional constructor order.
@dataclass
class GraphModule:
    graph: Annotated[Graph, 10]
    signature: Annotated[GraphSignature, 50]
    # This is used for unflattening, by tracking the calling structure of all of
    # the modules in order to unflatten the modules back to the eager calling
    # conventions.
    module_call_graph: Annotated[list[ModuleCallEntry], 60]
    metadata: Annotated[dict[str, str], 40] = field(default_factory=dict)
    # Mapping of namedtuple types to namedtuple field names, used for BC
    treespec_namedtuple_fields: Annotated[dict[str, NamedTupleDef], 70] = field(
        default_factory=dict
    )
|
|
|
|
|
|
# Invariant: Every time a change is made to the schema, one of the versions
# should be updated.
@dataclass
class SchemaVersion:
    # Major version number is bumped every time a breaking change is made.
    major: Annotated[int, 10]
    # Minor version number is bumped when a compatible change is made.
    minor: Annotated[int, 20]
|
|
|
|
|
|
# A serialized exported program: the graph module plus opset versions, range
# constraints, schema version, verifier names and the torch version string.
@dataclass
class ExportedProgram:
    graph_module: Annotated[GraphModule, 10]
    # Key is the opset namespace (ex. aten), and value is the version number
    opset_version: Annotated[dict[str, int], 20]
    range_constraints: Annotated[dict[str, RangeConstraint], 30]
    schema_version: Annotated[SchemaVersion, 60]
    verifiers: Annotated[list[str], 70] = field(default_factory=list)
    # Used when no torch version is supplied.
    torch_version: Annotated[str, 80] = "<=2.4"
|
|
|
|
|
|
#########################################################################
|
|
# Container types for inference tasks, not being used directly for export.
|
|
#########################################################################
|
|
|
|
|
|
# A collection of exported programs keyed by method name.
@dataclass
class Program:
    methods: Annotated[dict[str, ExportedProgram], 200]
|
|
|
|
|
|
# This is the top-level model definition that will be serialized into the package
@dataclass
class Model:
    # unique identifier of the model in the package, e.g. local, remote, merge
    name: Annotated[str, 10]
    # key is the FQN of tensor in exported program
    # value is the archive path of tensor payloads
    # e.g. "L__self__linear.weight" : "/data/tensor/L__self__linear.weight"
    tensorPaths: Annotated[dict[str, str], 20]
    # program exported from torch.export()
    program: Annotated[Program, 40]
    # Backend-specialized Lowered GraphModule
    # e.g. "aotinductor-a100" : ExportedProgram_with_AOTInductor_delegate
    delegates: Annotated[dict[str, Program], 50]
    deviceAllocationMap: Annotated[dict[str, str], 60]
    # key is the FQN of constant in exported program (constant tensor or torchbind objs)
    # value is the archive path of serialized constants
    constantPaths: Annotated[dict[str, str], 70]
|
|
|
|
|
|
#
# The structure is used to serialize instances of AOTInductorModel to pass
# them from the publishing pipeline to the predictor.
#
# All new fields should be marked as optional.
#
@dataclass
class AOTInductorModelPickleData:
    # Base name of an associated .so AOTInductor library. Typically looks like:
    # "abc.so".
    library_basename: Annotated[str, 1]

    # AOTInductor engine input names.
    input_names: Annotated[list[str], 2]

    # AOTInductor engine output names.
    output_names: Annotated[list[str], 3]

    # These fields tell whether floating point inputs/outputs should be converted to
    # a certain type. If None, the dtypes that the AOTInductor engine inferred from the sample
    # inputs are used.
    floating_point_input_dtype: Annotated[Optional[int], 4] = None
    floating_point_output_dtype: Annotated[Optional[int], 5] = None

    # Whether AOTInductor runtime is for CPU.
    aot_inductor_model_is_cpu: Annotated[Optional[bool], 6] = None
|
|
|
|
|
|
# A named wrapper around a single Node.
@dataclass
class ExternKernelNode:
    # name is not the unique identifier of the node
    name: Annotated[str, 10]
    node: Annotated[Node, 20]
|
|
|
|
|
|
# Container holding a list of ExternKernelNode entries.
@dataclass
class ExternKernelNodes:
    nodes: Annotated[list[ExternKernelNode], 10]
|