.. _export.ir_spec:

torch.export IR Specification
=============================

Export IR is an intermediate representation (IR) for compilers, which bears
similarities to MLIR and TorchScript. It is specifically designed to express
the semantics of PyTorch programs. Export IR primarily represents computation
as a streamlined list of operations, with limited support for dynamism such as
control flow.

To create an Export IR graph, a frontend can be used that soundly captures a
PyTorch program via a trace-specializing mechanism. The resulting Export IR
can then be optimized and executed by a backend. This can be done today
through :func:`torch.export.export`.

The key concepts covered in this document are:

- ExportedProgram: the data structure containing the Export IR program
- Graph: consists of a list of nodes
- Node: represents operations and control flow, and carries metadata
- Values: produced and consumed by nodes
- Types: associated with values and nodes
- Size and memory layout: defined for values

Assumptions
-----------

This document assumes that the reader is sufficiently familiar with PyTorch,
specifically with :class:`torch.fx` and its related tooling. It therefore does
not repeat content that is already covered in the :class:`torch.fx`
documentation and paper.

What is Export IR
-----------------

Export IR is a graph-based intermediate representation of PyTorch programs.
Export IR is realized on top of :class:`torch.fx.Graph`. In other words, **all
Export IR graphs are also valid FX graphs**, and if interpreted using standard
FX semantics, Export IR can be interpreted soundly. One implication is that an
exported graph can be converted to a valid Python program via standard FX
codegen.

This documentation primarily focuses on highlighting the areas where Export IR
differs from FX in terms of strictness, while skipping the parts where it is
similar to FX.

ExportedProgram
---------------

The top-level Export IR construct is the :class:`torch.export.ExportedProgram`
class. It bundles the computational graph of a PyTorch model (which is usually
a :class:`torch.nn.Module`) with the parameters or weights that this model
consumes.

Some notable attributes of the :class:`torch.export.ExportedProgram` class are:

- ``graph_module`` (:class:`torch.fx.GraphModule`): Data structure containing
  the flattened computational graph of the PyTorch model. The graph can be
  directly accessed through ``ExportedProgram.graph``.
- ``graph_signature`` (:class:`torch.export.ExportGraphSignature`): The graph
  signature, which specifies the parameter and buffer names used and mutated
  within the graph. Instead of storing parameters and buffers as attributes of
  the graph, they are lifted as inputs to the graph. The ``graph_signature``
  is used to keep track of additional information on these parameters and
  buffers.
- ``state_dict`` (``Dict[str, Union[torch.Tensor, torch.nn.Parameter]]``): Data
  structure containing the parameters and buffers.
- ``range_constraints`` (``Dict[sympy.Symbol, RangeConstraint]``): For programs
  exported with data-dependent behavior, the metadata on each node will
  contain symbolic shapes (which look like ``s0``, ``i0``). This attribute
  maps each symbolic shape to its lower and upper bound.
- ``equality_constraints`` (``List[Tuple[InputDim, InputDim]]``): A list of
  pairs of input dimensions in the graph that are constrained to have the
  same size.

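For illustration, a minimal sketch of inspecting these attributes (the module
and input shapes here are made up for the example)::

    import torch
    from torch import nn

    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    ep = torch.export.export(M(), (torch.randn(2, 4),))
    print(ep.graph_module)     # flattened computational graph
    print(ep.graph_signature)  # specs for lifted parameters/buffers and user inputs
    print(list(ep.state_dict)) # e.g. ['linear.weight', 'linear.bias']
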
Graph
-----

An Export IR Graph is a PyTorch program represented in the form of a DAG
(directed acyclic graph). Each node in this graph represents a particular
computation or operation, and edges of this graph consist of references
between nodes.

We can view a Graph as having this schema:

.. code-block:: python

    class Graph:
        nodes: List[Node]

In practice, Export IR's graph is realized as the :class:`torch.fx.Graph`
Python class.

An Export IR graph contains the following nodes (nodes are described in more
detail in the next section):

- 0 or more nodes of op type ``placeholder``
- 0 or more nodes of op type ``call_function``
- exactly 1 node of op type ``output``

**Corollary:** The smallest valid Graph consists of one node, i.e. ``nodes``
is never empty.

**Definition:**
The set of ``placeholder`` nodes of a Graph represents the **inputs** of the
Graph (and of the GraphModule). The ``output`` node of a Graph represents the
**outputs** of the Graph (and of the GraphModule).

Example::

    import torch
    from torch import nn

    class MyModule(nn.Module):

        def forward(self, x, y):
            return x + y

    ep = torch.export.export(MyModule(), (torch.randn(2), torch.randn(2)))
    print(ep.graph)

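Printing the graph produces textual output along these lines (node names and
``num_users`` counts are illustrative and may vary between releases)::

    graph():
        %x : [num_users=1] = placeholder[target=x]
        %y : [num_users=1] = placeholder[target=y]
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})
        return (add,)
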
The above is the textual representation of a Graph, with each line being a
node.

Node
----

A Node represents a particular computation or operation and is represented in
Python using the :class:`torch.fx.Node` class. Edges between nodes are
represented as direct references to other nodes via the ``args`` property of
the Node class. Using the same FX machinery, we can represent the operations
that a computational graph typically needs: operator calls, placeholders (aka
inputs), conditionals, and loops.

The Node has the following schema:

.. code-block:: python

    class Node:
        name: str  # name of node
        op_name: str  # type of operation

        # interpretation of the fields below depends on op_name
        target: Union[str, Callable]
        args: List[object]
        kwargs: Dict[str, object]
        meta: Dict[str, object]

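As a minimal sketch, these fields map directly onto :class:`torch.fx.Node`
attributes and can be read off any exported graph (``ep`` is the exported
program from the example above)::

    for node in ep.graph.nodes:
        # e.g. prints: add call_function aten.add.Tensor (x, y) {}
        print(node.name, node.op, node.target, node.args, node.kwargs)
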
**FX Text Format**

As in the example above, notice that each line has this format::

    %<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …), kwargs = {"keyword": arg5})

This format captures everything present in the Node class, with the exception
of ``meta``, in a compact format.

Concretely:

- **<name>** is the name of the node as it would appear in ``node.name``.

- **<op_name>** is the ``node.op`` field, which must be one of:
  ``call_function``, ``placeholder``, ``get_attr``, or ``output``.

- **<target>** is the target of the node as ``node.target``. The meaning of
  this field depends on ``op_name``.

- **arg1, …, arg4** are the values listed in the ``node.args`` tuple. If a
  value in the list is a :class:`torch.fx.Node`, it is indicated with a
  leading ``%``.

For example, a call to the add operator would appear as::

    %add1 = call_function[target = torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {})

where ``%x`` and ``%y`` are two other Nodes with names ``x`` and ``y``. It is
worth noting that the string ``torch.ops.aten.add.Tensor`` represents the
callable object that is actually stored in the target field, not merely its
string name.

The final line of this text format is::

    return [add]

which is a Node with ``op_name = output``, indicating that we are returning
this one element.

call_function
^^^^^^^^^^^^^

A ``call_function`` node represents a call to an operator.

**Definitions**

- **Functional:** We say a callable is "functional" if it satisfies all of
  the following requirements:

  - Non-mutating: The operator does not mutate the value of its inputs (for
    tensors, this includes both metadata and data).
  - No side effects: The operator does not mutate state that is visible from
    the outside, like changing the values of module parameters.

- **Operator:** A functional callable with a predefined schema. Examples of
  such operators include functional ATen operators.

**Representation in FX**

.. code-block::

    %name = call_function[target = operator](args = (%x, %y, …), kwargs = {})

**Differences from vanilla FX call_function**

1. In an FX graph, a ``call_function`` node can refer to any callable; in
   Export IR, it is restricted to a select subset of ATen operators, custom
   operators, and control flow operators.

2. In Export IR, constant arguments are embedded within the graph.

3. In an FX graph, a ``get_attr`` node can represent reading any attribute
   stored in the graph module. However, in Export IR this is restricted to
   reading only submodules, as all parameters/buffers are passed in as inputs
   to the graph module.

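For illustration, a sketch that lists the ``call_function`` targets in an
exported graph (reusing ``ep`` from the example above); in Export IR these
resolve to ATen operator overloads such as ``aten.add.Tensor`` rather than
arbitrary Python callables::

    for node in ep.graph.nodes:
        if node.op == "call_function":
            # e.g. prints: add aten.add.Tensor
            print(node.name, node.target)
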
Metadata
~~~~~~~~

``Node.meta`` is a dict attached to every FX node. However, the FX spec does
not specify what metadata can or will be there. Export IR provides a stronger
contract: specifically, all ``call_function`` nodes are guaranteed to have,
and only have, the following metadata fields:

- ``node.meta["stack_trace"]`` is a string containing the Python stack trace
  referencing the original Python source code. An example stack trace looks
  like::

    File "my_module.py", line 19, in forward
        return x + dummy_helper(y)
    File "helper_utility.py", line 89, in dummy_helper
        return y + 1

- ``node.meta["val"]`` describes the output of running the operation. It can
  be of type ``SymInt``, ``FakeTensor``, a
  ``List[Union[FakeTensor, SymInt]]``, or ``None``.

- ``node.meta["nn_module_stack"]`` describes the "stack trace" of the
  :class:`torch.nn.Module` from which the node came, if it came from a
  :class:`torch.nn.Module` call. For example, if a node containing the
  ``addmm`` op was called from a :class:`torch.nn.Linear` module inside of a
  :class:`torch.nn.Sequential` module, the ``nn_module_stack`` would look
  something like::

    {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}

- ``node.meta["source_fn_stack"]`` contains the torch function or the leaf
  :class:`torch.nn.Module` class this node was called from before
  decomposition. For example, a node containing the ``addmm`` op from a
  :class:`torch.nn.Linear` module call would contain :class:`torch.nn.Linear`
  in its ``source_fn``, and a node containing the ``addmm`` op from a
  :func:`torch.nn.functional.linear` call would contain
  :func:`torch.nn.functional.linear` in its ``source_fn``.

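As a minimal sketch (again reusing ``ep`` from the ``MyModule`` example),
this metadata can be inspected directly off the nodes::

    for node in ep.graph.nodes:
        if node.op == "call_function":
            print(node.meta["val"])                   # FakeTensor describing the output
            print(node.meta.get("nn_module_stack"))   # module provenance, if any
            print(node.meta.get("stack_trace"))       # original Python source location
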
placeholder
^^^^^^^^^^^

A placeholder represents an input to a graph. Its semantics are exactly the
same as in FX. Placeholder nodes must be the first N nodes in the nodes list
of a graph; N can be zero.

**Representation in FX**

.. code-block:: python

    %name = placeholder[target = name](args = ())

The ``target`` field is a string which is the name of the input.

``args``, if non-empty, should be of size 1, representing the default value
of this input.

**Metadata**

Placeholder nodes also have ``meta['val']``, like ``call_function`` nodes.
The ``val`` field in this case represents the input shape/dtype that the
graph is expected to receive for this input parameter.

output
^^^^^^

An output call represents a return statement in a function; it thus
terminates the current graph. There is one and only one output node, and it
will always be the last node of the graph.

**Representation in FX**

.. code-block::

    output[](args = (%something, …))

This has the exact same semantics as in :class:`torch.fx`. ``args``
represents the node to be returned.

**Metadata**

The output node has the same metadata as ``call_function`` nodes.

get_attr
^^^^^^^^

``get_attr`` nodes represent reading a submodule from the encapsulating
:class:`torch.fx.GraphModule`. Unlike in a vanilla FX graph from
:func:`torch.fx.symbolic_trace`, in which ``get_attr`` nodes are used to read
attributes such as parameters and buffers from the top-level
:class:`torch.fx.GraphModule`, in Export IR parameters and buffers are passed
in as inputs to the graph module and stored in the top-level
:class:`torch.export.ExportedProgram`.

**Representation in FX**

.. code-block:: python

    %name = get_attr[target = name](args = ())

**Example**

Consider the following model::

    from functorch.experimental.control_flow import cond

    def true_fn(x):
        return x.sin()

    def false_fn(x):
        return x.cos()

    def f(x, y):
        return cond(y, true_fn, false_fn, [x])

Graph::

    graph():
        %x_1 : [num_users=1] = placeholder[target=x_1]
        %y_1 : [num_users=1] = placeholder[target=y_1]
        %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
        %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
        %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {})
        return conditional

The line ``%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]``
reads the submodule ``true_graph_0``, which contains the ``sin`` operator.

References
----------

SymInt
^^^^^^

A SymInt is an object that can either be a literal integer or a symbol that
represents an integer (represented in Python by the ``sympy.Symbol`` class).
When a SymInt is a symbol, it describes a variable of type integer that is
unknown to the graph at compile time; that is, its value is only known at
runtime.

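For illustration, a minimal sketch of how a SymInt arises when exporting with
a dynamic dimension (the module and dimension name here are made up for the
example)::

    import torch
    from torch import nn

    class M(nn.Module):
        def forward(self, x):
            return x * 2

    # Mark dimension 0 of ``x`` as dynamic; it becomes a symbol (e.g. ``s0``).
    batch = torch.export.Dim("batch")
    ep = torch.export.export(M(), (torch.randn(3, 4),), dynamic_shapes={"x": {0: batch}})

    # The placeholder's FakeTensor now has a symbolic first dimension,
    # e.g. torch.Size([s0, 4]).
    print(next(iter(ep.graph.nodes)).meta["val"].shape)
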
FakeTensor
^^^^^^^^^^

A FakeTensor is an object that contains the metadata of a tensor. It can be
viewed as having the following metadata.

.. code-block:: python

    class FakeTensor:
        size: List[SymInt]
        dtype: torch.dtype
        device: torch.device
        dim_order: List[int]  # This doesn't exist yet

The size field of a FakeTensor is a list of integers or SymInts. If SymInts
are present, this means the tensor has a dynamic shape. If integers are
present, it is assumed that the tensor will have that exact static shape. The
rank of a FakeTensor is never dynamic. The dtype field represents the dtype
of the output of that node. There are no implicit type promotions in Export
IR. There are no strides in FakeTensor.

In other words:

- If the operator in ``node.target`` returns a Tensor, then
  ``node.meta['val']`` is a FakeTensor describing that tensor.
- If the operator in ``node.target`` returns an n-tuple of Tensors, then
  ``node.meta['val']`` is an n-tuple of FakeTensors describing each tensor.
- If the operator in ``node.target`` returns an int/float/scalar that is
  known at compile time, then ``node.meta['val']`` is None.
- If the operator in ``node.target`` returns an int/float/scalar that is not
  known at compile time, then ``node.meta['val']`` is of type SymInt.

For example:

- ``aten::add`` returns a Tensor, so its spec will be a FakeTensor with the
  dtype and size of the tensor returned by this operator.
- ``aten::sym_size`` returns an integer, so its val will be a SymInt because
  its value is only available at runtime.
- ``max_pool2d_with_indices`` returns a tuple of (Tensor, Tensor), so the
  spec will also be a 2-tuple of FakeTensor objects; the first FakeTensor
  describes the first element of the return value, etc.

Python code::

    def add_one(x):
        return torch.ops.aten.add(x, 1)

Graph::

    graph():
        %ph_0 : [#users=1] = placeholder[target=ph_0]
        %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {})
        return [add_tensor]

FakeTensor::

    FakeTensor(dtype=torch.int, size=[2,], device=CPU)

Pytree-able Types
^^^^^^^^^^^^^^^^^

We define a type as "Pytree-able" if it is either a leaf type or a container
type that contains other Pytree-able types.

Note:

    The concept of pytree is the same as the one documented
    `here <https://jax.readthedocs.io/en/latest/pytrees.html>`__ for JAX.

The following types are defined as **leaf type**:

.. list-table::
   :widths: 50 50
   :header-rows: 1

   * - Type
     - Definition
   * - Tensor
     - :class:`torch.Tensor`
   * - Scalar
     - Any numerical type from Python, including integral types, floating-point types, and zero-dimensional tensors.
   * - int
     - Python int (bound as int64_t in C++)
   * - float
     - Python float (bound as double in C++)
   * - bool
     - Python bool
   * - str
     - Python string
   * - ScalarType
     - :class:`torch.dtype`
   * - Layout
     - :class:`torch.layout`
   * - MemoryFormat
     - :class:`torch.memory_format`
   * - Device
     - :class:`torch.device`

The following types are defined as **container type**:

.. list-table::
   :widths: 50 50
   :header-rows: 1

   * - Type
     - Definition
   * - Tuple
     - Python tuple
   * - List
     - Python list
   * - Dict
     - Python dict with Scalar keys
   * - NamedTuple
     - Python namedtuple
   * - Dataclass
     - Must be registered through `register_dataclass <https://github.com/pytorch/pytorch/blob/901aa85b58e8f490631ce1db44e6555869a31893/torch/export/__init__.py#L693>`__
   * - Custom class
     - Any custom class defined with `_register_pytree_node <https://github.com/pytorch/pytorch/blob/901aa85b58e8f490631ce1db44e6555869a31893/torch/utils/_pytree.py#L72>`__

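For illustration, a minimal sketch of exporting a module whose input is a
container type (the module here is made up for the example); the dict input
is pytree-flattened into individual placeholder nodes in the exported graph::

    import torch
    from torch import nn

    class TakesDict(nn.Module):
        def forward(self, d):
            return d["a"] + d["b"]

    inputs = ({"a": torch.randn(2), "b": torch.randn(2)},)
    ep = torch.export.export(TakesDict(), inputs)

    # The dict input shows up as separate placeholders, one per leaf.
    print(ep.graph)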