Expose torch.export() API (#106904)

Other class definitions and utilities will be moved in subsequent PRs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106904
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
gmagogsfm 2023-08-16 10:47:23 +00:00 committed by PyTorch MergeBot
parent 528a2c0aa9
commit ddba7a5a55
5 changed files with 229 additions and 3 deletions

View File

@@ -1,6 +1,11 @@
torch._export.export
torch.export
=====================
.. TODO: Add torch.export() tutorial here.
.. automodule:: torch
.. autofunction:: export
.. warning::
    This feature is a prototype and may have compatibility breaking changes in the future.

View File

@@ -68,7 +68,7 @@ Features described in this documentation are classified by release status:
cuda
mps
torch.backends <backends>
export
torch.export <export>
torch.distributed <distributed>
torch.distributed.algorithms.join <distributed.algorithms.join>
torch.distributed.elastic <distributed.elastic>

View File

@@ -706,6 +706,18 @@ Symbolic Numbers
sym_min
sym_not
Export Path
-------------

.. warning::
    This feature is a prototype and may have compatibility breaking changes in the future.

.. autosummary::
    :toctree: generated
    :nosignatures:

    export
    generated/exportdb/index
Optimizations
-------------
.. autosummary::

View File

@@ -33,7 +33,7 @@ if _running_with_deploy():
else:
from .torch_version import __version__ as __version__
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union, List
import builtins
__all__ = [
@@ -55,6 +55,7 @@ __all__ = [
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'SymBool', 'sym_not',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
'export',
]
################################################################################
@@ -1765,6 +1766,212 @@ if not _running_with_deploy():
return cls.ops_table[(op_key, dispatch_key)]
################################################################################
# Exposing torch.export() and related support classes
################################################################################
# TODO(ycao): Move ExportedProgram, Constraint, dynamic_dim etc to torch.compiler namespace
def export(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    constraints: Optional[List["torch._dynamo.eval_frame.Constraint"]] = None,
) -> "torch._export.exported_program.ExportedProgram":  # type: ignore[name-defined]
"""
`torch.export()` is a one-shot process for capturing a computation graph from
a PyTorch program Ahead-of-Time (AOT).
This function traces a callable (an nn.Module, a function or a method)
containing PyTorch operations and produces an ExportedProgram. The
ExportedProgram includes PyTorch operations that perform computations
equivalent to those in the given nn.Module or callable.
In specific terms, `torch.export()` traces a function `f` by executing it
with the provided `args` and `kwargs`. It records the PyTorch operations
invoked during execution to produce the ExportedProgram.
**Acceptable input/output types**
Acceptable types of inputs (for `args` and `kwargs`) and outputs include:
- Primitive types, i.e. `torch.Tensor`, `int`, `float`, `bool` and `str`.
- Dataclasses (must be registered with
`torch._export.utils.register_dataclass_as_pytree_node` first)
- (Nested) data structures comprising `dict`, `list`, `tuple`, `namedtuple` and
`OrderedDict` containing all of the above types.
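For example, a function whose input is a nested container of tensors can be
exported directly (a minimal sketch; `fn` and its example inputs are
illustrative)::

    import torch

    def fn(inputs):
        # `inputs` is a dict holding a tuple of tensors; pytree flattening
        # handles the nesting automatically.
        a, b = inputs["pair"]
        return a + b

    example_inputs = ({"pair": (torch.rand(2, 2), torch.rand(2, 2))},)
    ep = torch.export(fn, example_inputs)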
**What's specialized in the program?**
1. Non-tensor inputs
`torch.export()` specializes the traced program based on the values of
inputs that are not torch.Tensors, i.e. `int`, `float`, `bool` and `str`.
For example::
    def fn(x: torch.Tensor, i: int):
        return x + i

    example_inputs = (torch.rand(2, 2), 1)  # i is set to 1 in example inputs
    ep = torch.export(fn, example_inputs)

would yield an `ExportedProgram` containing the following graph::
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=0] = placeholder[target=arg1_1]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
return (add,)
Notice that `%add` is computed by adding `%arg0_1` and the constant `1`
rather than `%arg1_1`, because integers are specialized.
2. Rank and static shapes (not values) of input Tensors
The rank of a tensor is always specialized and treated as a constant. Sizes of
dimensions are also specialized as constants, i.e. static shapes, unless
specified as dynamic via the `dynamic_dim` API, for example::
    def fn(x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x

    example_inputs = (torch.rand(10, 2),)
    ep = torch.export(fn, example_inputs)

would produce an `ExportedProgram` containing the following graph::
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
return (add,)
You can see that the conditional on `x.shape[0] > 5` is removed because the
example input has the static shape `(10, 2)`. Since `torch.export()`
specializes on the static shape, the `else` branch is never reached and
therefore does not show up in the exported program.
Note:
If you want to preserve dynamic branching behavior based on the value or
shape of a torch.Tensor in the generated graph, you will need to use
`torch._export.dynamic_dim` to make a dimension of the input tensor dynamic
and rewrite the source code using control flow operations like
`torch.ops.higher_order.cond`.
3. Control flow
By default, control flow (like `if`) branching decisions are specialized
according to the execution flow observed during the tracing run. See the
following section on how to preserve dynamic control flow.
**How to express Dynamism**
1. Shape Dynamism
Because static shape use cases are dominant, `torch.export()` assumes all
shapes are static by default unless there are explicit user instructions
that say otherwise. Specifically, users can use `torch._export.dynamic_dim`
to give `torch.export()` a hint about the dynamism and range of an input
tensor dimension, as in the sketch below.
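For example, to mark dim 0 of an input as dynamic with an upper bound
(a minimal sketch; `fn` and the bound of 1024 are illustrative)::

    import torch
    from torch._export import dynamic_dim

    def fn(x):
        return x * 2

    x = torch.rand(10, 2)
    # Allow dim 0 to vary at runtime, up to an illustrative bound of 1024.
    constraints = [dynamic_dim(x, 0), dynamic_dim(x, 0) <= 1024]
    ep = torch.export(fn, (x,), constraints=constraints)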
2. Dynamic Control Flow
To preserve dynamic branching behavior of control flow (like `if`), users
need to rewrite the source code of the original program to use PyTorch's
higher order operators (like `torch.ops.higher_order.cond`), as sketched below.
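For example, a data-dependent branch can be rewritten with `cond` (a minimal
sketch; the import path `functorch.experimental.control_flow` is an
assumption for this point in time)::

    import torch
    from functorch.experimental.control_flow import cond  # assumed import path

    def fn(x):
        def true_fn(x):
            return x + 1

        def false_fn(x):
            return x - 1

        # Both branches share the same signature; the predicate may be
        # data-dependent, and both branches are captured in the graph.
        return cond(x.sum() > 0, true_fn, false_fn, [x])

    ep = torch.export(fn, (torch.rand(3, 2),))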
**Soundness Guarantee**
While tracing, `torch.export()` takes note of shape-related assumptions
made by the user program and the underlying PyTorch operator kernels.
The output ExportedProgram is considered valid only when these
assumptions hold true.
There are two types of assumptions made during tracing:
- Shapes (not values) of input tensors.
- Ranges (lower and upper bound) of values extracted from intermediate
tensors via `.item()` or direct indexing.
All assumptions must be validated at graph capture time for `torch.export()`
to succeed. Specifically:
- Assumptions on the static shapes of input tensors are validated
automatically, without additional effort.
- Assumptions on the dynamic shapes of input tensors require an explicit
`Input Constraint` constructed with the `torch._export.dynamic_dim` API.
- Assumptions on the range of intermediate values require an explicit
`Inline Constraint`, constructed with the `constrain_as_size` and
`constrain_as_value` APIs, as in the sketch below.
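An inline constraint on a value read out of a tensor might look like this
(a minimal sketch; the import path and the exact `constrain_as_size`
signature are assumptions)::

    import torch
    from torch._export.constraints import constrain_as_size  # assumed path

    def fn(x):
        n = x.max().item()  # value extracted from an intermediate tensor
        # Assumed signature: declare that n is a valid size in [1, 10].
        constrain_as_size(n, min=1, max=10)
        return torch.zeros(n)

    ep = torch.export(fn, (torch.randint(1, 10, (4,)),))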
If any assumption cannot be validated, a fatal error will be raised. When that
happens, the error message will include the code needed to construct the
necessary constraints to validate the assumptions; for example, `torch.export()`
would suggest the following code for input constraints::
    def specify_constraints(x):
        return [
            # x:
            dynamic_dim(x, 0),
            dynamic_dim(x, 0) <= 5,
        ]
This example means the program requires dim 0 of input `x` to be less than
or equal to 5 in order to be valid. You can inspect the constraints needed,
then copy this exact function into your code to generate the constraints to
pass into the `constraints` argument, as shown below.
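Wiring the suggested constraints back into the export call could then look
like (a sketch reusing `specify_constraints` from above)::

    x = torch.rand(4, 2)
    ep = torch.export(fn, (x,), constraints=specify_constraints(x))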
**ExportedProgram Invariants**
The returned `ExportedProgram` maintains the following invariants:
- It is guaranteed to be a sound representation of the original
program.
- It maintains the exact calling convention of the original program.
- It contains a `state_dict` that stores the `torch.nn.Parameter`s
involved in computation of the original program.
- It includes an `fx.GraphModule` that represents the computation of
the original program. The GraphModule:
- Contains only `placeholder`, `call_function`, `get_attr` and `return` nodes.
- Inlines all submodules from the original program.
- Lifts all parameters and buffers of the original program as
inputs to the graph.
- Does not mutate intermediate values, parameters, or buffers.
- Does not include operations with side effects.
- Contains only a curated subset of ATen operations and registered
custom operations (by default). See the list of Core ATen Ops
here: https://pytorch.org/docs/stable/ir.html
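To inspect these pieces, one can do something like the following (a sketch;
the `graph_module` and `state_dict` attribute names are assumptions based on
the invariants above)::

    ep = torch.export(fn, example_inputs)
    print(ep.graph_module.graph)       # the captured fx graph
    print(list(ep.state_dict.keys()))  # lifted parameters and buffers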
Args:
f: The callable to trace.
args: Example positional inputs.
kwargs: Optional example keyword inputs.
constraints: An optional list of constraints on the dynamic arguments
that specify their possible range of shapes. By default, shapes of
input torch.Tensors are assumed to be static. If an input torch.Tensor
is expected to have dynamic shapes, please use `torch._export.dynamic_dim()`
to define `Constraint` objects that specify the dynamism and the possible
range of shapes. See the `torch._export.dynamic_dim()` docstring for
examples of how to use it.
Returns:
An ExportedProgram containing the traced callable.
"""
    from torch._export import export
    return export(f, args, kwargs, constraints)
# Deprecated attributes
_deprecated_attrs = {
"has_mps": torch.backends.mps.is_built,
@@ -1784,6 +1991,7 @@ if TYPE_CHECKING:
_lazy_modules = {
"_dynamo",
"_inductor",
"_export",
# ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit
"onnx",
}

View File

@@ -147,6 +147,7 @@ def get_ignored_functions() -> Set[Callable]:
torch.empty_permuted,
torch.empty_strided,
torch.empty_quantized,
torch.export,
torch.eye,
torch.fft.fftfreq,
torch.fft.rfftfreq,