TorchDynamo Deeper Dive
=======================

**Author**: `Jason Ansel <https://github.com/jansel>`_

What is a guard?
----------------

TorchDynamo operates just-in-time and specializes graphs based on
dynamic properties. For example, the first graph above has the following
guards:

::

    GUARDS:
     - local 'a' TENSOR_MATCH
     - local 'b' TENSOR_MATCH
     - global 'torch' FUNCTION_MATCH

If any of those guards fail, the graph will be recaptured and
recompiled. The interesting guard type there is ``TENSOR_MATCH``, which
checks the following ``torch.Tensor`` properties:

- Python class of the tensor (tensor subclassing, etc)
- dtype
- device
- requires_grad
- dispatch_key (with thread-local includes/excludes applied)
- ndim
- sizes\* (optional)
- strides\* (optional)

For sizes/strides you can disable this specialization by setting the
following parameter:

.. code-block:: python

    torch._dynamo.config.dynamic_shapes = True

The full specialization mode allows the backend compiler to assume an
entirely static graph. Unfortunately, most backends require fully static
graphs. Operators which return dynamic shapes will trigger a graph break
when not in dynamic shape mode.

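To make the guard behavior concrete, here is a minimal sketch that counts
compilations with a hypothetical do-nothing backend named
``counting_backend`` (the backend and function names are illustrative, not
from the guide). Re-calling the function with tensors of the same properties
reuses the cached graph, while changing a guarded property such as ``dtype``
fails the ``TENSOR_MATCH`` guard and triggers a recompile:

.. code-block:: python

    from typing import List

    import torch
    import torch._dynamo

    compiled_graphs = []

    def counting_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        # Hypothetical backend: record every (re)compilation, then run the
        # captured FX graph eagerly by returning its forward callable.
        compiled_graphs.append(gm)
        return gm.forward

    @torch._dynamo.optimize(counting_backend)
    def fn(x):
        return torch.sin(x) + 1

    fn(torch.randn(4))                        # compiled: one graph so far
    fn(torch.randn(4))                        # guards pass: no recompile
    fn(torch.randn(4, dtype=torch.float64))   # dtype changed: TENSOR_MATCH fails, recompile
    print(len(compiled_graphs))               # expected: 2

With the default static-shape specialization, changing the input sizes would
fail the optional sizes/strides guards in the same way.
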
What is Dynamo doing?
---------------------

If you want to understand better what TorchDynamo is doing, you can set:

.. code-block:: python

    import torch._dynamo.config
    import logging

    torch._dynamo.config.log_level = logging.INFO
    torch._dynamo.config.output_code = True

This code triggers useful (but spammy) printouts.

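The printouts below refer to a ``toy_example`` function that the guide
compiles with a trivial custom backend called ``my_compiler``. As a
self-contained reference, a minimal sketch consistent with the FX graph and
bytecode shown below might look like this (the exact surrounding code in the
original example may differ):

.. code-block:: python

    from typing import List

    import torch
    import torch._dynamo

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()  # needs the `tabulate` package
        return gm.forward         # run the captured graph eagerly

    @torch._dynamo.optimize(my_compiler)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    toy_example(torch.randn(10), torch.randn(10))
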
For example, the printouts for the first graph in the ``toy_example``
are:

::

    __compiled_fn_0 <eval_with_key>.1
    opcode         name     target                                                   args              kwargs
    -------------  -------  -------------------------------------------------------  ----------------  --------
    placeholder    a        a                                                        ()                {}
    placeholder    b        b                                                        ()                {}
    call_function  abs_1    <built-in method abs of type object at 0x7f9ca082f8a0>   (a,)              {}
    call_function  add      <built-in function add>                                  (abs_1, 1)        {}
    call_function  truediv  <built-in function truediv>                              (a, add)          {}
    call_method    sum_1    sum                                                      (b,)              {}
    call_function  lt       <built-in function lt>                                   (sum_1, 0)        {}
    output         output   output                                                   ((truediv, lt),)  {}

    ORIGINAL BYTECODE toy_example example.py 9
     10           0 LOAD_FAST                0 (a)
                  2 LOAD_GLOBAL              0 (torch)
                  4 LOAD_METHOD              1 (abs)
                  6 LOAD_FAST                0 (a)
                  8 CALL_METHOD              1
                 10 LOAD_CONST               1 (1)
                 12 BINARY_ADD
                 14 BINARY_TRUE_DIVIDE
                 16 STORE_FAST               2 (x)

     11          18 LOAD_FAST                1 (b)
                 20 LOAD_METHOD              2 (sum)
                 22 CALL_METHOD              0
                 24 LOAD_CONST               2 (0)
                 26 COMPARE_OP               0 (<)
                 28 POP_JUMP_IF_FALSE       38

     12          30 LOAD_FAST                1 (b)
                 32 LOAD_CONST               3 (-1)
                 34 BINARY_MULTIPLY
                 36 STORE_FAST               1 (b)

     13     >>   38 LOAD_FAST                2 (x)
                 40 LOAD_FAST                1 (b)
                 42 BINARY_MULTIPLY
                 44 RETURN_VALUE

    MODIFIED BYTECODE
      9           0 LOAD_GLOBAL              3 (__compiled_fn_0)
                  2 LOAD_FAST                0 (a)
                  4 LOAD_FAST                1 (b)
                  6 CALL_FUNCTION            2
                  8 UNPACK_SEQUENCE          2
                 10 STORE_FAST               2 (x)
                 12 POP_JUMP_IF_FALSE       24
                 14 LOAD_GLOBAL              4 (__resume_at_30_1)
                 16 LOAD_FAST                1 (b)
                 18 LOAD_FAST                2 (x)
                 20 CALL_FUNCTION            2
                 22 RETURN_VALUE
             >>  24 LOAD_GLOBAL              5 (__resume_at_38_2)
                 26 LOAD_FAST                1 (b)
                 28 LOAD_FAST                2 (x)
                 30 CALL_FUNCTION            2
                 32 RETURN_VALUE

    GUARDS:
     - local 'a' TENSOR_MATCH
     - local 'b' TENSOR_MATCH
     - global 'torch' FUNCTION_MATCH

At the top you can see the FX graph.
Next, you see the original bytecode of the function, followed by the
modified bytecode generated by TorchDynamo. Finally, you see the guards
which we covered above.

In the modified bytecode, ``__compiled_fn_0`` is the return value of
``my_compiler()`` (the compiled graph). ``__resume_at_30_1`` and
``__resume_at_38_2`` are both generated continuation functions that pick
up execution after a graph break (at bytecode offsets 30 and 38). Each
of these functions takes the form:

::

    __resume_at_<offset>:
        ... restore stack state if needed ...
        JUMP_ABSOLUTE <offset> into toy_example
        ... original bytecode of toy_example ...

By generating this ``__resume_at_<offset>`` function, we force the remainder
of the function to be executed in a new Python frame, which recursively
triggers TorchDynamo to restart its capture once execution reaches that
point for the first time.

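If you run the hypothetical ``toy_example`` sketch from earlier with inputs
that exercise both sides of the ``if``, you can watch this lazy capture
happen: each continuation is captured only the first time its branch is
reached, so the backend ends up being invoked three times in total (the
first graph plus one continuation per branch). A rough illustration,
assuming the sketch above:

.. code-block:: python

    import torch

    # Assumes toy_example and my_compiler from the sketch above.

    # First call: captures the first graph (__compiled_fn_0) plus the
    # continuation for the branch this b takes (b.sum() < 0 here).
    toy_example(torch.randn(10), -torch.ones(10))

    # Second call takes the other branch, so the remaining continuation
    # is captured now and my_compiler runs once more.
    toy_example(torch.randn(10), torch.ones(10))

    # Further calls with tensors of the same properties reuse the cached
    # graphs; my_compiler is not invoked again.
    toy_example(torch.randn(10), torch.ones(10))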