mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Beef up the allow_in_graph docs (#127117)
We make the following changes: - most of the time when someone uses allow_in_graph, they actually wanted to make a custom op. We add a link to the custom ops landing page and explain the differences between allow_in_graph and custom ops. - we warn people against using allow_in_graph footguns and document them. Test Plan: - tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/127117 Approved by: https://github.com/jansel, https://github.com/albanD
This commit is contained in:
parent
e24a87ed8d
commit
08653fe355
|
|
@ -74,22 +74,12 @@ def assume_constant_result(fn):
|
|||
|
||||
def allow_in_graph(fn):
|
||||
"""
|
||||
Customize which functions TorchDynamo will include in the generated
|
||||
graph. Similar to `torch.fx.wrap()`.
|
||||
::
|
||||
Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
|
||||
and instead directly write it to the graph when encountered.
|
||||
|
||||
torch._dynamo.allow_in_graph(my_custom_function)
|
||||
See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation
|
||||
|
||||
@torch._dynamo.optimize(...)
|
||||
def fn(a):
|
||||
x = torch.add(x, 1)
|
||||
x = my_custom_function(x)
|
||||
x = torch.add(x, 1)
|
||||
return x
|
||||
|
||||
fn(...)
|
||||
|
||||
Will capture a single graph containing `my_custom_function()`.
|
||||
WARNING: this API can be a footgun, please read the documentation carefully.
|
||||
"""
|
||||
if isinstance(fn, (list, tuple)):
|
||||
return [allow_in_graph(x) for x in fn]
|
||||
|
|
|
|||
|
|
@ -32,22 +32,77 @@ def reset() -> None:
|
|||
|
||||
def allow_in_graph(fn):
|
||||
"""
|
||||
Customize which functions compilation will include in the generated graph.
|
||||
It bypasses all introspection of the symbolic python code in favor of
|
||||
directly writing it to the graph.
|
||||
If fn is a list or tuple of callables it recursively applies :func:`allow_in_graph()`
|
||||
to each function and returns a new list or tuple containing the modified functions
|
||||
Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
|
||||
and instead directly write it to the graph when encountered.
|
||||
|
||||
Args:
|
||||
fn: A callable representing the function to be included in the graph.
|
||||
If you are using :func:`torch.compile` (with backend="inductor" (the default)), or
|
||||
:func:`torch.export.export`, and trying to black-box a Python function throughout
|
||||
all tracing, do not use this API.
|
||||
Instead, please create a custom operator (see :ref:`custom-ops-landing-page`)
|
||||
|
||||
.. warning::
|
||||
|
||||
:func:`allow_in_graph` skips TorchDynamo completely on the decorated function
|
||||
skipping all TorchDynamo safety checks (graph breaks, handling closures, etc).
|
||||
Therefore, one has to be very careful with :func:`allow_in_graph` since subsystems
|
||||
like AOT Autograd rely on torchdynamo
|
||||
If not careful, this could lead to soundness and really hard-to-debug issues.
|
||||
If you're a typical torch.compile user (e.g. you're applying torch.compile to
|
||||
a model to make it run faster), you probably don't want to use this function.
|
||||
:func:`allow_in_graph` is a footgun because it skips the compiler frontend
|
||||
(Dynamo) that is responsible for doing safety checks (graph breaks, handling
|
||||
closures, etc). Incorrect usage will lead to difficult-to-debug silent
|
||||
incorrectness issues.
|
||||
|
||||
Given a Python function with no allow_in_graph decorator, regular execution
|
||||
of torch.compile traces through the function. :func:`allow_in_graph` changes
|
||||
it so that the frontend does not trace inside the function, but the compiler
|
||||
backend still traces through it. Compare this to custom operators, which
|
||||
treats a function as a black box throughout the torch.compile stack. The following
|
||||
table compares these mechanisms.
|
||||
|
||||
+------------------------+-----------------------+--------------------------------+
|
||||
| Mechanism | Frontend (Dynamo) | Backend (AOTAutograd+Inductor) |
|
||||
+========================+=======================+================================+
|
||||
| no decorator | trace inside | trace inside |
|
||||
+------------------------+-----------------------+--------------------------------+
|
||||
| allow_in_graph | opaque callable | trace inside |
|
||||
+------------------------+-----------------------+--------------------------------+
|
||||
| custom op | opaque callable | opaque callable |
|
||||
+------------------------+-----------------------+--------------------------------+
|
||||
|
||||
One common use case for :func:`allow_in_graph()` is as an escape hatch for the compiler
|
||||
frontend: if you know the function works w.r.t. to the downstream components of the
|
||||
compilation stack (AOTAutograd and Inductor) but there is a Dynamo bug that prevents it from
|
||||
symbolically introspecting the function properly (or if your code is in C/C++ and
|
||||
therefore cannot be introspected with Dynamo), then one can decorate said function
|
||||
with :func:`allow_in_graph` to bypass Dynamo.
|
||||
|
||||
We require that ``fn`` adhere to the following restrictions. Failure to adhere
|
||||
results in undefined behavior:
|
||||
|
||||
- The inputs to ``fn`` must be Proxy-able types in the FX graph. Valid types include:
|
||||
Tensor/int/bool/float/None/List[Tensor?]/List[int?]/List[float?]
|
||||
Tuple[Tensor?, ...]/Tuple[int?, ...]/Tuple[float?, ...]/torch.dtype/torch.device
|
||||
- The outputs to ``fn`` must be Proxy-able types in the FX graph (see previous bullet)
|
||||
- all Tensors used inside of ``fn`` must be passed directly as inputs to ``fn``
|
||||
(as opposed to being captured variables).
|
||||
|
||||
Args:
|
||||
fn: A callable representing the function to be included in the graph.
|
||||
If ``fn`` is a list or tuple of callables it recursively applies
|
||||
:func:`allow_in_graph()` to each function and returns a new list or
|
||||
tuple containing the modified functions.
|
||||
|
||||
Example::
|
||||
|
||||
torch.compiler.allow_in_graph(my_custom_function)
|
||||
|
||||
@torch.compile(...)
|
||||
def fn(a):
|
||||
x = torch.add(x, 1)
|
||||
x = my_custom_function(x)
|
||||
x = torch.add(x, 1)
|
||||
return x
|
||||
|
||||
fn(...)
|
||||
|
||||
Will capture a single graph containing ``my_custom_function()``.
|
||||
|
||||
"""
|
||||
import torch._dynamo
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user