Beef up the allow_in_graph docs (#127117)

We make the following changes:
- most of the time, someone who reaches for allow_in_graph actually
  wants a custom op. We add a link to the custom ops landing page and
  explain the differences between allow_in_graph and custom ops (see
  the sketch after this list).
- we document allow_in_graph's footguns and warn people about them.
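
A minimal sketch of the custom-op alternative, for context (the op name
"mylib::numpy_sin" and its NumPy-backed body are hypothetical, and
torch.library.custom_op assumes PyTorch 2.4+):

    import numpy as np
    import torch

    # A custom op is opaque to the whole torch.compile stack:
    # Dynamo, AOTAutograd, and Inductor all treat it as a black box.
    @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
    def numpy_sin(x: torch.Tensor) -> torch.Tensor:
        return torch.from_numpy(np.sin(x.numpy()))

    # A fake (meta) kernel lets the compiler infer output shapes and
    # dtypes without running the real implementation.
    @numpy_sin.register_fake
    def _(x):
        return torch.empty_like(x)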

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127117
Approved by: https://github.com/jansel, https://github.com/albanD
rzou 2024-05-31 12:18:50 -07:00 committed by PyTorch MergeBot
parent e24a87ed8d
commit 08653fe355
2 changed files with 71 additions and 26 deletions

torch/_dynamo/decorators.py

@@ -74,22 +74,12 @@ def assume_constant_result(fn):
 def allow_in_graph(fn):
     """
-    Customize which functions TorchDynamo will include in the generated
-    graph. Similar to `torch.fx.wrap()`.
-    ::
+    Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
+    and instead directly write it to the graph when encountered.
 
-        torch._dynamo.allow_in_graph(my_custom_function)
+    See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation.
 
-        @torch._dynamo.optimize(...)
-        def fn(a):
-            x = torch.add(x, 1)
-            x = my_custom_function(x)
-            x = torch.add(x, 1)
-            return x
-
-        fn(...)
-
-    Will capture a single graph containing `my_custom_function()`.
+    WARNING: this API can be a footgun; please read the documentation carefully.
     """
     if isinstance(fn, (list, tuple)):
         return [allow_in_graph(x) for x in fn]

torch/compiler/__init__.py

@@ -32,22 +32,77 @@ def reset() -> None:
 def allow_in_graph(fn):
     """
-    Customize which functions compilation will include in the generated graph.
-    It bypasses all introspection of the symbolic python code in favor of
-    directly writing it to the graph.
-    If fn is a list or tuple of callables it recursively applies :func:`allow_in_graph()`
-    to each function and returns a new list or tuple containing the modified functions
+    Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
+    and instead directly write it to the graph when encountered.
 
-    Args:
-        fn: A callable representing the function to be included in the graph.
+    If you are using :func:`torch.compile` (with the default ``backend="inductor"``) or
+    :func:`torch.export.export`, and you are trying to black-box a Python function
+    throughout all tracing, do not use this API.
+    Instead, please create a custom operator (see :ref:`custom-ops-landing-page`).
 
     .. warning::
 
-        :func:`allow_in_graph` skips TorchDynamo completely on the decorated function
-        skipping all TorchDynamo safety checks (graph breaks, handling closures, etc).
-        Therefore, one has to be very careful with :func:`allow_in_graph` since subsystems
-        like AOT Autograd rely on torchdynamo
-        If not careful, this could lead to soundness and really hard-to-debug issues.
+        If you're a typical torch.compile user (e.g. you're applying torch.compile to
+        a model to make it run faster), you probably don't want to use this function.
+        :func:`allow_in_graph` is a footgun because it skips the compiler frontend
+        (Dynamo) that is responsible for doing safety checks (graph breaks, handling
+        closures, etc). Incorrect usage will lead to difficult-to-debug silent
+        incorrectness issues.
+
+    Given a Python function with no allow_in_graph decorator, regular execution
+    of torch.compile traces through the function. :func:`allow_in_graph` changes
+    it so that the frontend does not trace inside the function, but the compiler
+    backend still traces through it. Compare this to custom operators, which
+    treat a function as a black box throughout the torch.compile stack. The following
+    table compares these mechanisms.
+
+    +------------------------+-----------------------+--------------------------------+
+    | Mechanism              | Frontend (Dynamo)     | Backend (AOTAutograd+Inductor) |
+    +========================+=======================+================================+
+    | no decorator           | trace inside          | trace inside                   |
+    +------------------------+-----------------------+--------------------------------+
+    | allow_in_graph         | opaque callable       | trace inside                   |
+    +------------------------+-----------------------+--------------------------------+
+    | custom op              | opaque callable       | opaque callable                |
+    +------------------------+-----------------------+--------------------------------+
+
+    One common use case for :func:`allow_in_graph()` is as an escape hatch for the
+    compiler frontend: if you know the function works with the downstream components
+    of the compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that
+    prevents it from symbolically introspecting the function properly (or if your code
+    is in C/C++ and therefore cannot be introspected with Dynamo), then one can decorate
+    said function with :func:`allow_in_graph` to bypass Dynamo.
+
+    We require that ``fn`` adhere to the following restrictions. Failure to adhere
+    results in undefined behavior:
+
+    - The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include:
+      Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?]
+      Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device
+    - The outputs of ``fn`` must be Proxy-able types in the FX graph (see the previous bullet).
+    - All Tensors used inside of ``fn`` must be passed directly as inputs to ``fn``
+      (as opposed to being captured variables).
+
+    Args:
+        fn: A callable representing the function to be included in the graph.
+            If ``fn`` is a list or tuple of callables, it recursively applies
+            :func:`allow_in_graph()` to each function and returns a new list or
+            tuple containing the modified functions.
+
+    Example::
+
+        torch.compiler.allow_in_graph(my_custom_function)
+
+        @torch.compile(...)
+        def fn(a):
+            x = torch.add(a, 1)
+            x = my_custom_function(x)
+            x = torch.add(x, 1)
+            return x
+
+        fn(...)
+
+    Will capture a single graph containing ``my_custom_function()``.
     """
     import torch._dynamo
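
A usage sketch of the escape-hatch pattern described in the new docstring.
The function fused_gate is hypothetical; it stands in for code that Dynamo
fails to introspect but that AOTAutograd and Inductor handle fine:

    import torch

    @torch.compiler.allow_in_graph
    def fused_gate(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Obeys the restrictions above: every Tensor arrives as an input
        # (nothing is captured), and inputs/outputs are Proxy-able types.
        return torch.sigmoid(x) * y

    @torch.compile
    def model(x, y):
        # Dynamo records fused_gate as an opaque call in its graph;
        # the compiler backend still traces through its body.
        return fused_gate(x + 1, y)

    model(torch.randn(4), torch.randn(4))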
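
And a sketch of the silent-incorrectness footgun the warning refers to,
assuming the captured global gets baked in at trace time because Dynamo
installs no guards for a function it does not introspect:

    import torch

    scale = 2.0

    @torch.compiler.allow_in_graph
    def scaled(x):
        return x * scale  # captured variable, not an input: the footgun

    @torch.compile
    def f(x):
        return scaled(x)

    f(torch.ones(3))   # traced and compiled with scale == 2.0
    scale = 3.0
    f(torch.ones(3))   # may silently still compute x * 2.0: no guard, no retrace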