mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[JIT] Add Exit Transform / Convert To SSA to docs
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/24114 Differential Revision: D19780828 Pulled By: eellison fbshipit-source-id: d481ad886b2ad6349a1646672e507336d45759fb
This commit is contained in:
parent
b0476dc6e6
commit
ca33aeba09
|
|
@ -41,6 +41,7 @@ Sections start with a reference to the source file where the code related to the
|
||||||
- [SugaredValue](#sugaredvalue)
|
- [SugaredValue](#sugaredvalue)
|
||||||
- [Resolver](#resolver)
|
- [Resolver](#resolver)
|
||||||
- [Environment](#environment)
|
- [Environment](#environment)
|
||||||
|
- [SSA Conversion](#convert_to_ssa)
|
||||||
- [Python-Compiler Interaction](#python-compiler-interaction)
|
- [Python-Compiler Interaction](#python-compiler-interaction)
|
||||||
- [Executing Programs](#executing-programs)
|
- [Executing Programs](#executing-programs)
|
||||||
- [Evaluation Semantics](#evaluation-semantics)
|
- [Evaluation Semantics](#evaluation-semantics)
|
||||||
|
|
@ -201,7 +202,7 @@ Iterators for the `nodes()` list are invalided when the current Node they point
|
||||||
|
|
||||||
Block also contain a list of input and output values. The meaning of these values depends on where the block is used. For the Graph's top-level block, these are inputs and outputs to the Graph, and line up with the FunctionSchema associated with a Method.
|
Block also contain a list of input and output values. The meaning of these values depends on where the block is used. For the Graph's top-level block, these are inputs and outputs to the Graph, and line up with the FunctionSchema associated with a Method.
|
||||||
|
|
||||||
**Control-flow** is represented with using sub-blocks rather than a control-flow graph representation. A `prim::If` has one block for the true branch and one block for the else.A `prim:Loop` has a block for the loop body (there is no condition block, instead the end of the loop body computes whether to re-enter the loop body). This representation ensures we have structured control-flow. Currently TorchScript does not allow for early returns, breaking out of loops early. This limitation makes a lot of optimizations easier and is true for the vast majority of networks. Our frontend permits certain forms of syntax sugar that allow a limited amount of re-writing of if statements to avoid needing to support early returns. A Node can lookup what Block it is in, and a Block and can look up its parent (either the Node that has it as a subblock, or `nullptr` for the main Block).
|
**Control-flow** is represented with using sub-blocks rather than a control-flow graph representation. A `prim::If` has one block for the true branch and one block for the else.A `prim:Loop` has a block for the loop body (there is no condition block, instead the end of the loop body computes whether to re-enter the loop body). This representation ensures we have structured control-flow. This limitation makes a lot of optimizations easier and is true for the vast majority of networks. A Node can lookup what Block it is in, and a Block and can look up its parent (either the Node that has it as a subblock, or `nullptr` for the main Block).
|
||||||
|
|
||||||
### If ###
|
### If ###
|
||||||
For if-statements (`prim::If`) the Blocks have no inputs, and the outputs are the new values of variables in the outer block whose values were altered in an if-statement.
|
For if-statements (`prim::If`) the Blocks have no inputs, and the outputs are the new values of variables in the outer block whose values were altered in an if-statement.
|
||||||
|
|
@ -541,7 +542,101 @@ This makes it possible to use most of the compiler functionality when python is
|
||||||
|
|
||||||
[script/compiler.cpp](../script/compiler.cpp)
|
[script/compiler.cpp](../script/compiler.cpp)
|
||||||
|
|
||||||
The Environment object tracks the assignment of variable names to SugaredValues during compilation. It is local to the compiler file. A stack of environments exist, with a new environment being created for sub-blocks introduced by control flow. The Environment also handles turning the AST representation into SSA-form by tracking which variables were modified inside a sub-block and inserting the correct inputs/outputs to the Blocks of if-statements and loops.
|
The Environment object tracks the assignment of variable names during compilation. It is local to the compiler file. A stack of environments exist, with a new environment being created for sub-blocks introduced by control flow. The Environment keeps two tables, one for values which are not first class in the type system (Sugared values) and a type table for values which are. When first class values are set, we emit a prim::Store, and when they are referenced we emit a prim::Load. Sugared values are not re-assignable. The graph is converted to SSA in the convertToSSA pass.
|
||||||
|
|
||||||
|
## Conversion To SSA ##
|
||||||
|
|
||||||
|
[script/convert_to_ssa.cpp](../script/convert_to_ssa.cpp)
|
||||||
|
|
||||||
|
As explained in the * Block * section, the IR is represented in structured control flow composed of ifs & loops. This makes it easier to optimize and lower to other compilers which do not support unstructured control flow. We lower python control flow (break, continue, return) to this simplified form. We do closing over any variables in the environment, so we are able to convert all writes and reads from the environment directly to SSA form.
|
||||||
|
|
||||||
|
Conversion to SSA works in multiple parts.
|
||||||
|
- First, we add loads and stores to control flow operators (ifs & loops).
|
||||||
|
- Then we erase Break & Continue statements from the graph and replace them with `prim::LoopContinuation`. `prim::LoopContinuation` has the form `LoopContinuation(%loop_continue_condition, %loop_carried_vars)`. Break Statements have the continue condition set to false, and Continue statements inline the loop condition. %loop_carried_vars are the loop carried variables of the inner most loop that contains the Break or Continue statement, are added by inserting prim::Loads calls at the location of the statement.
|
||||||
|
- Then we inline the loop condition into the graph loops.
|
||||||
|
- Next we erase loads and stores, removing all Stores and replacing all loads
|
||||||
|
with whatever the in-scope value of the variable name is.
|
||||||
|
- Finally, we remove `prim::LoopContinuation`s and `prim::ReturnStmt`s in the exit_transform pass.
|
||||||
|
|
||||||
|
## Exit Transform ##
|
||||||
|
|
||||||
|
[script/exit_transform.cpp](../script/exit_transform.cpp)
|
||||||
|
|
||||||
|
This pass takes in a graph where LoopContinuation & ReturnStmts exist in the graph and erases them, correctly setting block outputs. `prim::LoopContinuation(*vals)` means that the values are targeting the most recent loop block. `prim::ReturnStmt(*vals)` means that the values are targeting the most recent Closure or Graph Block.
|
||||||
|
|
||||||
|
If a block has an exit node, no further instructions will be executed until the exit target has been reached. If we encounter a node that contains nested blocks that may have hit an exit node, such as an if statement that exits in one block and does not exit in the other, we use a boolean value to indicate if the exit has been hit or not. Then, we conditionalize further execution.
|
||||||
|
|
||||||
|
Python example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
while i < 5:
|
||||||
|
if i == 3:
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
i += 2
|
||||||
|
```
|
||||||
|
|
||||||
|
-> transforms to
|
||||||
|
|
||||||
|
```python
|
||||||
|
continue_loop = i < 5
|
||||||
|
while continue_loop:
|
||||||
|
if i == 3:
|
||||||
|
i = i + 1
|
||||||
|
continue_loop = i < 5
|
||||||
|
did_exit = True
|
||||||
|
if did_exit:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
i = i + 2
|
||||||
|
continue_loop = i < 5
|
||||||
|
```
|
||||||
|
|
||||||
|
The pass also keeps track of nodes or blocks that will always throw Exceptions so that we do not unnecessarily conditionalize execution. In the following example, we can treat the if statement as always Returning and remove the `print` statement.
|
||||||
|
|
||||||
|
```python
|
||||||
|
if i < 0:
|
||||||
|
raise Exception("Negative input")
|
||||||
|
else:
|
||||||
|
return math.sqrt(i)
|
||||||
|
print(i) # unreachable code
|
||||||
|
```
|
||||||
|
|
||||||
|
In the above example, the if statement will have one output, with the value on the false branch being `math.sqrt(i)`. In the true branch, insert and use
|
||||||
|
`prim::Uninitialized`. These are values inserted by the compiler when it can prove the value will never be used. It can be introduced by exceptions, breaks, continues, and returns.
|
||||||
|
|
||||||
|
We initially considered doing the Transform pass before Loads and Stores were removed from the graph. However, this breaks when a loop carried variable
|
||||||
|
is captured in a break or continue and then is refined in the rest of the loop body. In the below example, at the point of the `continue`, `x` has type `Optional[int]` but is refined to `int` after the continue statement.
|
||||||
|
|
||||||
|
```python
|
||||||
|
...
|
||||||
|
if cond:
|
||||||
|
if i < 3:
|
||||||
|
x = torch.jit.annotate(Optional[int], None)
|
||||||
|
continue
|
||||||
|
x = 1
|
||||||
|
else:
|
||||||
|
x = 2
|
||||||
|
print(x)
|
||||||
|
```
|
||||||
|
If we were to rearrange the graph before loads & stores were removed:
|
||||||
|
|
||||||
|
```python
|
||||||
|
if cond:
|
||||||
|
if i < 3:
|
||||||
|
x = torch.jit.annotate(Optional[int], None)
|
||||||
|
did_continue = True
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
did_continue = False
|
||||||
|
if not did_continue:
|
||||||
|
x = 1
|
||||||
|
else:
|
||||||
|
x = 2
|
||||||
|
if not did_continue:
|
||||||
|
print(x)
|
||||||
|
```
|
||||||
|
The type of `x` at the print statement would be `Optional[int]`, which breaks its original type.
|
||||||
|
|
||||||
## Python-Compiler Interaction ##
|
## Python-Compiler Interaction ##
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user