pytorch/torch/_export/serde/schema.py
Zhengxu Chen c919935cb7 [export] Update schema versioning format. (#116462)
Summary: Update the old versioning scheme to a major and minor version.

Test Plan: CI

Differential Revision: D52431963

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116462
Approved by: https://github.com/tugsbayasgalan
2024-01-03 17:34:58 +00:00

328 lines
7.2 KiB
Python

# NOTE: This is a placeholder for iterating on export serialization schema design.
# Anything is subject to change and no guarantee is provided at this point.
from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, List, Optional, Tuple
from torch._export.serde.union import _Union
# NOTE: Please update this value if any modifications are made to the schema
# (major, minor) pair — see the SchemaVersion dataclass at the bottom of this
# file for the bump rules (major = breaking change, minor = compatible change).
SCHEMA_VERSION = (3, 1)
# Version of the pytree TreeSpec string format used by ModuleCallSignature's
# in_spec/out_spec fields.
TREESPEC_VERSION = 1
class ScalarType(IntEnum):
    """Serialized tensor element type.

    The names mirror torch's scalar types (Byte, Char, ..., BFloat16);
    the integer values are part of the serialized format and must not change.
    """
    UNKNOWN = 0
    BYTE = 1
    CHAR = 2
    SHORT = 3
    INT = 4
    LONG = 5
    HALF = 6
    FLOAT = 7
    DOUBLE = 8
    COMPLEXHALF = 9
    COMPLEXFLOAT = 10
    COMPLEXDOUBLE = 11
    BOOL = 12
    BFLOAT16 = 13
class Layout(IntEnum):
    """Serialized tensor layout, mirroring the torch.layout variants.

    The integer values are part of the serialized format and must not change.
    """
    Unknown = 0
    SparseCoo = 1
    SparseCsr = 2
    SparseCsc = 3
    SparseBsr = 4
    SparseBsc = 5
    _mkldnn = 6
    Strided = 7
class MemoryFormat(IntEnum):
    """Serialized memory format, mirroring the torch.memory_format variants.

    The integer values are part of the serialized format and must not change.
    """
    Unknown = 0
    ContiguousFormat = 1
    ChannelsLast = 2
    ChannelsLast3d = 3
    PreserveFormat = 4
@dataclass
class Device:
    """Serialized device, e.g. type="cuda", index=0; index is None for
    devices without an index (e.g. plain "cpu")."""
    type: str
    index: Optional[int]
@dataclass(repr=False)
class SymExprHint(_Union):
    """Union holding the concrete hint value of a symbolic expression;
    exactly one of the fields is set."""
    as_int: int
    as_float: float
    as_bool: bool
# This is for storing the symbolic expressions behind symints/symfloats/symbools
# For example, we can get something like
# SymExpr(expr_str="s0 + s1", hint=SymExprHint(as_int=4))
# if we also have the hint that s0 and s1 are both 2.
@dataclass
class SymExpr:
    # Sympy-style expression rendered as a string, e.g. "s0 + s1".
    expr_str: str
    # Optional concrete value observed for the expression during tracing.
    hint: Optional[SymExprHint] = None
@dataclass(repr=False)
class SymInt(_Union):
    """A serialized symbolic integer: either a symbolic expression or a
    plain concrete int. Exactly one field is set."""
    as_expr: SymExpr
    as_int: int
@dataclass(repr=False)
class SymBool(_Union):
    """A serialized symbolic boolean: either a symbolic expression or a
    plain concrete bool. Exactly one field is set."""
    as_expr: SymExpr
    as_bool: bool
@dataclass
class TensorMeta:
    """Metadata describing a tensor value in the graph (no actual data);
    sizes/strides are SymInts so symbolic shapes round-trip."""
    dtype: ScalarType
    sizes: List[SymInt]
    requires_grad: bool
    device: Device
    strides: List[SymInt]
    storage_offset: int
    layout: Layout
# In most cases we will use the "as_name" field to store arguments which are
# SymInts.
# The "as_int" field is used in the case where we have a list containing a mix
# of SymInt and ints (ex. [1, s0, ...]). We will serialize this type of list to
# be List[SymIntArgument] and map the SymInts to the "as_name" field, and ints
# to the "as_int" field.
@dataclass(repr=False)
class SymIntArgument(_Union):
    # Name of a SymInt value in the graph (key into Graph.sym_int_values).
    as_name: str
    # Literal int, used when the argument is a plain constant.
    as_int: int
# In most cases we will use the "as_name" field to store arguments which are
# SymBools.
# The "as_bool" field is used in the case where we have a list containing a mix
# of SymBool and bools (ex. [True, i0, ...]). We will serialize this type of list to
# be List[SymBoolArgument] and map the SymBools to the "as_name" field, and bools
# to the "as_bool" field.
@dataclass(repr=False)
class SymBoolArgument(_Union):
    # Name of a SymBool value in the graph (key into Graph.sym_bool_values).
    as_name: str
    # Literal bool, used when the argument is a plain constant.
    as_bool: bool
@dataclass
class TensorArgument:
    """Reference to a tensor value in the graph by name
    (key into Graph.tensor_values)."""
    name: str
# This is used for storing the contents of a list which contains optional tensors
# (Tensor?[], ex. [Tensor, None, ...]), where the list will be serialized to the
# type List[OptionalTensorArgument], with tensor values serialized to the
# "as_tensor" field, and None values serialized to the "as_none" field.
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
    # Name of the tensor value in the graph.
    as_tensor: str
    # Placeholder for a None entry; the empty tuple carries no data.
    as_none: Tuple[()]
@dataclass
class GraphArgument:
    """A subgraph passed as an argument (e.g. to higher order ops such as
    cond/map); `graph` is the serialized subgraph itself."""
    name: str
    graph: 'Graph'
@dataclass
class CustomObjArgument:
    """Reference to a custom (torchbind) object argument by name."""
    name: str
# This is actually a union type: exactly one of the fields below is populated
# for any given Argument. It covers every value kind an operator input/output
# can take in the export IR.
@dataclass(repr=False)
class Argument(_Union):
    as_none: Tuple[()]
    as_tensor: TensorArgument
    as_tensors: List[TensorArgument]
    as_int: int
    as_ints: List[int]
    as_float: float
    as_floats: List[float]
    as_string: str
    as_strings: List[str]
    as_sym_int: SymIntArgument
    as_sym_ints: List[SymIntArgument]
    as_scalar_type: ScalarType
    as_memory_format: MemoryFormat
    as_layout: Layout
    as_device: Device
    as_bool: bool
    as_bools: List[bool]
    as_sym_bool: SymBoolArgument
    as_sym_bools: List[SymBoolArgument]
    as_graph: GraphArgument
    as_optional_tensors: List[OptionalTensorArgument]
    as_custom_obj: CustomObjArgument
@dataclass
class NamedArgument:
    """An argument paired with its name from the operator schema."""
    # Argument name from the operator schema
    name: str
    arg: Argument
@dataclass
class Node:
    """A single operator call in the graph."""
    # Fully qualified operator target, e.g. "aten.add.Tensor" —
    # TODO confirm exact format against the serializer.
    target: str
    inputs: List[NamedArgument]
    outputs: List[Argument]
    # Free-form string-to-string metadata (e.g. stack traces, source fqn).
    metadata: Dict[str, str]
@dataclass
class Graph:
    """The serialized computation graph: nodes plus value tables keyed by
    the names that TensorArgument / SymIntArgument / SymBoolArgument refer to."""
    inputs: List[Argument]
    outputs: List[Argument]
    nodes: List[Node]
    tensor_values: Dict[str, TensorMeta]
    sym_int_values: Dict[str, SymInt]
    sym_bool_values: Dict[str, SymBool]
    # This is for deserializing the submodule graphs from higher order ops
    # (ex. cond, map) where single tensor returns will just return a single
    # tensor, rather than following export schema and returning a singleton
    # list.
    is_single_tensor_return: bool = False
@dataclass
class UserInputSpec:
    """Input provided directly by the user when calling the program."""
    # Actually, only tensors and SymInts are allowed here
    arg: Argument
@dataclass
class InputToParameterSpec:
    """Graph input that is bound to a module parameter (by fully qualified name)."""
    arg: TensorArgument
    parameter_name: str
@dataclass
class InputToBufferSpec:
    """Graph input that is bound to a module buffer (by fully qualified name)."""
    arg: TensorArgument
    buffer_name: str
@dataclass
class InputToTensorConstantSpec:
    """Graph input that is bound to a lifted tensor constant (by name)."""
    arg: TensorArgument
    tensor_constant_name: str
@dataclass(repr=False)
class InputSpec(_Union):
    """Union describing what each graph input corresponds to;
    exactly one field is set."""
    user_input: UserInputSpec
    parameter: InputToParameterSpec
    buffer: InputToBufferSpec
    tensor_constant: InputToTensorConstantSpec
@dataclass
class UserOutputSpec:
    """Output returned to the user from the exported program."""
    arg: Argument
@dataclass
class LossOutputSpec:
    """Output holding the loss value (used for training/joint graphs)."""
    arg: TensorArgument
@dataclass
class BufferMutationSpec:
    """Output representing the new value of a mutated buffer
    (identified by fully qualified name)."""
    arg: TensorArgument
    buffer_name: str
@dataclass
class GradientToParameterSpec:
    """Output representing the gradient of a parameter
    (identified by fully qualified name)."""
    arg: TensorArgument
    parameter_name: str
@dataclass
class GradientToUserInputSpec:
    """Output representing the gradient of a user input
    (identified by the input's name)."""
    arg: TensorArgument
    user_input_name: str
@dataclass(repr=False)
class OutputSpec(_Union):
    """Union describing what each graph output corresponds to;
    exactly one field is set."""
    user_output: UserOutputSpec
    loss_output: LossOutputSpec
    buffer_mutation: BufferMutationSpec
    gradient_to_parameter: GradientToParameterSpec
    gradient_to_user_input: GradientToUserInputSpec
@dataclass
class GraphSignature:
    """Per-input / per-output roles of a Graph, in positional order
    matching Graph.inputs and Graph.outputs."""
    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]
@dataclass
class RangeConstraint:
    """Inclusive [min_val, max_val] bound on a symbolic value —
    TODO confirm inclusivity against the serializer."""
    min_val: int
    max_val: int
@dataclass
class ModuleCallSignature:
    """Flat input/output signature plus the pytree specs needed to
    reconstruct the original (nested) calling convention of a module."""
    inputs: List[Argument]
    outputs: List[Argument]
    # These are serialized by calling pytree.treespec_dumps
    # And deserialized by calling pytree.treespec_loads
    in_spec: str
    out_spec: str
@dataclass
class ModuleCallEntry:
    """Entry in the module call graph: a module's fully qualified name and,
    if it was called, its call signature."""
    fqn: str
    signature: Optional[ModuleCallSignature] = None
@dataclass
class GraphModule:
    """A serialized graph together with its signature and module call
    structure."""
    graph: Graph
    signature: GraphSignature
    # This is used for unflattening, by tracking the calling structure of all of
    # the modules in order to unflatten the modules back to the eager calling
    # conventions.
    module_call_graph: List[ModuleCallEntry]
# Invariant: Every time a change is made to the schema, one of the versions
# should be updated.
@dataclass
class SchemaVersion:
    major: int  # Major version number is bumped every time a breaking change is made.
    minor: int  # Minor version number is bumped when a compatible change is made.
@dataclass
class ExportedProgram:
    """Top-level serialized artifact for torch.export."""
    graph_module: GraphModule
    # Key is the opset namespace (ex. aten), and value is the version number
    opset_version: Dict[str, int]
    # Bounds for the symbolic values appearing in the graph, keyed by symbol name.
    range_constraints: Dict[str, RangeConstraint]
    # Version of this schema the artifact was written with (see SCHEMA_VERSION).
    schema_version: SchemaVersion
    # Name of the IR dialect the graph conforms to.
    dialect: str