Finally we have this PR to merge allow_in_graph/inline/skip trace rules into ```trace_rules.lookup_inner```, where we can define and lookup trace rules at both function level and file level. Going forward, this is the central place that we define and consulte Dynamo trace rule for any function.
* ```trace_rules.looup``` is the API can return allow_in_graph, inline or skip.
* ```skipfiles.check``` is the API can return inline or skip, since we have multiple places that only do inline/skip check.
* I'll move ```skipfiles.check``` to ```trace_rules.check``` as one of the follow-ups.
* Both functions consulte ```trace_rules.lookup_inner``` to get the tracing rule.
To avoid a single big PR, I left a few items as the follow-ups:
* Remove ```skipfiles.py``` and merge the code into ```trace_rules.py```.
* We do double check in ```symbolic_convert.check_inlineable```, will refactor and simplify it. We should only do inline/skip check before generating ```SkipFilesVariable``` and ```UserFunctionVariable```.
* Rename ```SkipFilesVariable``` as ```SkipFunctionVariable```, since we only handle functions.
* The inline/skip reasons are not logged for some cases, since the new lookup framework doesn't always return inline/skip reasons. I'll refactor loggings to record the inline/skip reason in next step.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118971
Approved by: https://github.com/jansel
```
def f():
def g():
return ()
print(g.__name__)
f()
```
The following script should print `g` (with or without torch.compile),
but prints `f.<locals>.g` with torch.compile.
The problem looks like we use the co_qualname when reconstructing the
NestedUserFunctionVariable. I switched this over to use the co_name.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118768
Approved by: https://github.com/yanboliang, https://github.com/jansel
The original motivation for MYPYINDUCTOR was a faster type checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this.
Perhaps erroneously, when I tee'ed up this PR I elected to delete the `follow_imports = skip` designations in the mypy-inductor.ini. This lead to a number of extra type error suppressions that I manually edited. You will need to review.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
This prepares the PR where we implement sets in terms of dicts.
To do so, rather than storing internally a dictionary that maps literals
to VariableTrackers, it stores (pretty much) a dictionary from VTs to VTs.
To do so, keys are wrapped in an opaque internal class _Hashable.
The Hashable class is opaque on purpose so that it fails hard if
if it inadvertently leaks back into user code.
We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make much clearer what are the things that need to be modified to add
a new supported key type to Dicts.
Fixes [#107595](https://www.internalfb.com/tasks?t=107595)
Fixes [#111603](https://www.internalfb.com/tasks?t=111603)
Re-PR of https://github.com/pytorch/pytorch/pull/111196 sadly due to reverts, we could not reuse @lezcano's original PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116785
Approved by: https://github.com/mlazos
Recent 2 triton PRs (https://github.com/openai/triton/pull/2701, https://github.com/openai/triton/pull/2756) change the interface for triton.compile, this PR added the necessary change on inductor side to work with both old and new compile API.
Also there is some simplification between compilation call in subprocess and the one in main process
- previously we pass warm_cache_only=True if the compilation happens in subprocess. But triton never use that argument in the currently used pin. So I removed that
- previously we only pass compute_capability if compilation happens in subprocess. The PR change that to always passing compute_capability to triton.compile no matter if the compilation happens in main or sub process.
Updated:
There are more interface change from triton side. E.g.
- tl.math.{min, max} now requires a propagate_nan argument
- JITFunction.run now requires a warmup argument. This affect the benchmarking phase of matmul max-autotune; on the other hand, JITFunction.run forbids stream argument now. Simply removing passing this in when benchmarking matmul triton kernel will work for both old and new version of triton.
- triton Autotuner change attribute name from 'warmup' to 'num_warmup' and from 'rep' to 'num_rep'. This cause dynamo failed to handle triton Autotuner object since dynamo TritonKernelVariable makes assumption about attribute names. It's used in some test cases that a model call triton Autotuner directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115878
Approved by: https://github.com/jansel
I found whatever the attribute `async_op` in collective ops is `True` or `False` explicitly set by the users, it always leads to the graph break because the argument `async_op` is wrappered as `ConstantVariable(bool)` in dynamo. So here we need to use the `value` for the judgement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115921
Approved by: https://github.com/jansel, https://github.com/wconstab
Recent 2 triton PRs (https://github.com/openai/triton/pull/2701, https://github.com/openai/triton/pull/2756) change the interface for triton.compile, this PR added the necessary change on inductor side to work with both old and new compile API.
Also there is some simplification between compilation call in subprocess and the one in main process
- previously we pass warm_cache_only=True if the compilation happens in subprocess. But triton never use that argument in the currently used pin. So I removed that
- previously we only pass compute_capability if compilation happens in subprocess. The PR change that to always passing compute_capability to triton.compile no matter if the compilation happens in main or sub process.
Updated:
There are more interface change from triton side. E.g.
- tl.math.{min, max} now requires a propagate_nan argument
- JITFunction.run now requires a warmup argument. This affect the benchmarking phase of matmul max-autotune; on the other hand, JITFunction.run forbids stream argument now. Simply removing passing this in when benchmarking matmul triton kernel will work for both old and new version of triton.
- triton Autotuner change attribute name from 'warmup' to 'num_warmup' and from 'rep' to 'num_rep'. This cause dynamo failed to handle triton Autotuner object since dynamo TritonKernelVariable makes assumption about attribute names. It's used in some test cases that a model call triton Autotuner directly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115878
Approved by: https://github.com/jansel
1. Removes calls to `replace_all` and `clone` and makes VTs mutable.
2. Properly handles Tuple Iterator mutation. Previously TupleIterator variables would only be properly reconstructed if they were advanced at least once in a frame. On calls to `next`, the source information would be lost (due to constructing a new iterator without using builder), which would ensure that during codegen the variable would be reconstructed from scratch. Now that VTs are mutated, the source is never lost, so we need to properly track mutation and handle it by replaying calls to `next` at the end of the modified bytecode.
3. Added test for checking iadd side effects, this was missing in our unit test coverage.
4. Fixed two incorrect sources, DelayGraphBreakVariable, and UserMethodVariable both relied on setting the source to AttrSource(parent, name) at the callsite of `var_getattr`.
5. Fixed a bug in inplace adding for lists, it would set the resulting VariableTracker's source to `None` which would utilize a different reconstruct path in codegen. Now this is handled explicitly by reconstructing vars when allow_cache=`False`, so that during side effect replay, the mutated var is correctly updated.
In subsequent PRs:
* Refactoring side effect tracking to be significantly simpler (I think we only need an `is_modified` flag)
* Refactor `next_variables` iterator to match the signature of `next`
* Remove all references to `options` in the code
* Refactor VTs representing mutable collections to implement their own mutation update handling
* Remove clone and/or make it specific to lists for creating slices
* Add mutation tracking/replay for sets
* Add mutation tracking/replay for iter.py
* Removing setting source in builder (it's set at the top level after a var is returned)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725
Approved by: https://github.com/jansel
Updated version of #108885 addressing the review. In this PR:
- We add a VT.can_reconstruct utility that checks if VT.reconstruct()
does something.
- If functools.wraps(fn) is passed a `fn` that either has a source or
has .can_reconstruct() == True, then we stash the source (or the VT)
- Later on, we use the source (or VT.reconstruct) to actually
reconstruct the object in codegen.
Test Plan:
- New tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114279
Approved by: https://github.com/voznesenskym
This prepares the PR where we implement sets in terms of dicts.
To do so, rather than storing internally a dictionary that maps literals
to VariableTrackers, it stores (pretty much) a dictionary from VTs to VTs.
To do so, keys are wrapped in an opaque internal class `_Hashable`.
The Hashable class is opaque on purpose so that it fails hard if
if it inadvertently leaks back into user code.
We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make much clearer what are the things that need to be modified to add
a new supported key type to Dicts.
Fixes https://github.com/pytorch/pytorch/issues/107595
Fixes https://github.com/pytorch/pytorch/issues/111603
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111196
Approved by: https://github.com/jansel
The main thrust of the initial effort here was to capture `register_hook` calls on tensors in compile regions. The first part of this was done in https://github.com/pytorch/pytorch/pull/108903 wherein we added support for register_hook input tensors.
The distinction between input and intermediary is due to implementation differences.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
**The plan:**
For tensors w/ a source: (The PR above)
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2 modified bytecode with the original eager code, we call register_hook. This registration of hooks in residuals is sound because (a) it happens right after a Pt2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the users invoking frame. This means we can soundly know it will be around to invoke register_hook on. As long as we guard on the identity of the lifted function, this is sound to do.
For tensors w/o a source: (This PR)
Ostensibly, the most correct and complete solution would be to smuggle hooks into a runtime wrapper in aot_autograd, where all the items the hooks close over are lifted to inputs as necessary and passed alongside the user provided function. This is necessary so that we can properly trace out and capture all the mutations within the user defined hook at backwards time.
This is too complicated - so, we limited the scope of this initial PR to a simple subset of hooks:
- Hooks must have a source (be known to us already, not a lambda or intermediary defined function)
- We must be tracing under compiled autograd
**The flow**:
We use the HOP added in https://github.com/pytorch/pytorch/pull/109690/files, referred to as the HOP below.
1) We intercept register_hook calls and wrap the user defined fn in the HOP
2) We write a `_register_hook_trampoline` to the graph that is a local no-arg function that is invoked as a call_function in the dynamo graph
3) aot_autograd inlines through it during its trace, and sees the HOP
4) the HOP preserves itself in the graph - it does not get traced into
5) During backwards, compiled_autograd installs the HOP under a hook call
6) When compiled_autograd enters compilation over its generated graph, dynamo traces the contents of the hook
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109537
Approved by: https://github.com/ezyang
The strategy in this PR is pretty straightforward.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
The plan:
**For tensors w/ a source:**
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2 modified bytecode with the original eager code, we call `register_hook`. This registration of hooks in residuals is sound because (a) it happens right after a Pt2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the users invoking frame. This means we can soundly know it will be around to invoke `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.
**For tensors w/o a source:**
Graph break - we will support this in a subsequent PR
**Handles:**
An interesting new component here is the creation of a `STORE_FAST `->`LOAD_FAST` associated with the handle, the return result of `register_hook`. If the user code stored the result of `register_hook` in a handle, we need to honor that. We do so by interceding into `STORE_FAST`, and recording the name of the local variable as directed by user code. We then honor that same name in the reconstructed bytecode. If the user did not store a hook, we merely pop the produced value to preserve the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108903
Approved by: https://github.com/ezyang
ghstack dependencies: #108846, #109092
The strategy for supporting functools partials is relatively straightforward.
There are 2 cases we need to support:
**1) Functools partials as input**
In this case, we are first seeing the functools partial and it is guaranteed to have a source. As such, the args, keywords, and func of the functools partial are passed through VariableBuilder. As this is the first time we are seeing these objects (as it is an input), we re-enter VariableBuilder with a source referencing the args, keywords, and func as attributes of the input to produce:
- func: A callable VariableTracker (UDF, TorchVariable, etc) depending on the value of `func`
- args: List[VariableTracker] - note, not ListVariableTracker!
- keywords: Dict[str, VariableTracker]
A major benefit of this structure is that it very elegantly matches the args to `call_function`.
We then compose a FunctoolsPartialVariable from the VariableTrackers made above.
**2) Functools partials created within compile**
In this case, we already have all the args as known VTs, and thus just compose a FunctoolsPartialVariable as we do for case (1).
For both (1) and (2) - we propagate all guards from the func, args, and keyword VTs to the FunctoolsPartialVariable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108846
Approved by: https://github.com/ezyang, https://github.com/jansel
When inlining a function which loads a closure, its direct parent may not load that closure. So we cannot find the closure name in parent's symbolic locals. In this PR, we fix it by recursively searching the parent instruction translator stack to resolve the closure.
**Background**
When developing https://github.com/pytorch/pytorch/pull/105679, this corner case is triggered. A small repro is added in the test of this pr, where outer is loaded by deep2 but not by deep.
```python
def test_inline_closure_not_loaded_by_parent(self):
def outer(a):
return a + 1
def indirect(x):
return direct(x)
def direct(x):
def deep2(c):
return outer(c)
def deep(c):
return deep2(c)
return deep(x)
x = torch.randn(3)
eager = indirect(x)
counter = CompileCounter()
compiled = torch._dynamo.optimize(counter)(indirect)(x)
```
Running the test, we have the following error before the PR:
```
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6584, in test_inline_closure_not_loaded_by_parent
compiled = torch._dynamo.optimize(counter)(indirect)(x)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 321, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 481, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 130, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 362, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 194, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 531, in _compile
raise InternalTorchDynamoError(str(e)).with_traceback(e.__traceback__) from None
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in _compile
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 417, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2067, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2279, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2279, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 724, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 688, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1116, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2172, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2227, in inline_call_
sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 471, in bind_args
result[name] = parent.symbolic_locals[name]
torch._dynamo.exc.InternalTorchDynamoError: outer
from user code:
File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6570, in indirect
return direct(x)
File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6579, in direct
return deep(x)
File "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6577, in deep
return deep2(c)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
To execute this test, run the following from the base repo dir:
python test/dynamo/test_misc.py -k test_inline_closure_not_loaded_by_parent
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
---------------------------------------------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------------------------------------------
frames [('total', 1)]
inline_call []
---------------------------------------------------------------------------------------------------------------------------- Captured stderr call -----------------------------------------------------------------------------------------------------------------------------
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping helper /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __init__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping __enter__ /home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/contextlib.py
[2023-08-02 15:48:36,560] torch._dynamo.eval_frame: [DEBUG] skipping enable_dynamic /home/yidi/local/pytorch/torch/_dynamo/eval_frame.py
[2023-08-02 15:48:36,561] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6569
TRACE starts_line indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6569
def indirect(x):
[2023-08-02 15:48:36,591] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['x'] (3,) [<DimDynamic.STATIC: 2>] [None]
TRACE starts_line indirect /home/yidi/local/pytorch/test/dynamo/test_misc.py:6570
return direct(x)
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF direct []
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [UserFunctionVariable()]
[2023-08-02 15:48:36,594] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [UserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object direct at 0x7fbe4d366810, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6572>
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6572 (inline depth: 1)
def direct(x):
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6573 (inline depth: 1)
def deep2(c):
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CLOSURE outer []
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_TUPLE 1 [InlinedClosureVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST <code object deep2 at 0x7fbe4d3666b0, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6573> [TupleVariable()]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST MiscTests.test_inline_closure_not_loaded_by_parent.<locals>.direct.<locals>.deep2 [TupleVariable(), ConstantVariable(code)]
[2023-08-02 15:48:36,595] torch._dynamo.symbolic_convert: [DEBUG] TRACE MAKE_FUNCTION 8 [TupleVariable(), ConstantVariable(code), ConstantVariable(str)]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_DEREF deep2 [NestedUserFunctionVariable()]
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6576 (inline depth: 1)
def deep(c):
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CLOSURE deep2 []
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_TUPLE 1 [NewCellVariable()]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576> [TupleVariable()]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_CONST MiscTests.test_inline_closure_not_loaded_by_parent.<locals>.direct.<locals>.deep [TupleVariable(), ConstantVariable(code)]
[2023-08-02 15:48:36,597] torch._dynamo.symbolic_convert: [DEBUG] TRACE MAKE_FUNCTION 8 [TupleVariable(), ConstantVariable(code), ConstantVariable(str)]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE STORE_FAST deep [NestedUserFunctionVariable()]
TRACE starts_line direct /home/yidi/local/pytorch/test/dynamo/test_misc.py:6579 (inline depth: 1)
return deep(x)
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST deep []
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST x [NestedUserFunctionVariable()]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [NestedUserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,598] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576>
TRACE starts_line deep /home/yidi/local/pytorch/test/dynamo/test_misc.py:6576 (inline depth: 2)
def deep(c):
TRACE starts_line deep /home/yidi/local/pytorch/test/dynamo/test_misc.py:6577 (inline depth: 2)
return deep2(c)
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF deep2 []
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST c [NestedUserFunctionVariable()]
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION 1 [NestedUserFunctionVariable(), TensorVariable()]
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object deep at 0x7fbe4d366760, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6576>
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
[2023-08-02 15:48:36,599] torch._dynamo.symbolic_convert: [DEBUG] FAILED INLINING <code object direct at 0x7fbe4d366810, file "/home/yidi/local/pytorch/test/dynamo/test_misc.py", line 6572>
[2023-08-02 15:48:36,599] torch._dynamo.output_graph: [DEBUG] restore_graphstate: removed 0 nodes
```
Test Plan:
add new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106491
Approved by: https://github.com/williamwen42, https://github.com/jansel, https://github.com/zou3519
The main complexity comes from the __init__ function of Dataclass variables which look something like this
```
[2023-07-10 05:01:29,548] torch._dynamo.symbolic_convert: [DEBUG] INLINING <code object __init__ at 0x7f7015154450, file "<string>", line 2>
3 0 LOAD_FAST 1 (b)
2 LOAD_FAST 0 (self)
4 STORE_ATTR 0 (b)
4 6 LOAD_FAST 2 (named_tensors)
8 LOAD_DEREF 0 (_HAS_DEFAULT_FACTORY)
10 IS_OP 0
12 POP_JUMP_IF_FALSE 20
14 LOAD_DEREF 1 (_dflt_named_tensors)
16 CALL_FUNCTION 0
18 JUMP_FORWARD 2 (to 22)
>> 20 LOAD_FAST 2 (named_tensors)
>> 22 LOAD_FAST 0 (self)
24 STORE_ATTR 1 (named_tensors)
26 LOAD_CONST 0 (None)
28 RETURN_VALUE
```
There are multiple issues
* VariableBuilder call in functions.py was wrong. We were calling *options as args.
* We were not setting source while tracking the new object. This led to no source for Dataclass variable, which has some new variables in its closures as seen in the above bytecode.
* There is IS_OP in above bytecode, which brings more cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104840
Approved by: https://github.com/jansel
Fix https://github.com/pytorch/pytorch/issues/99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`.
My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in https://github.com/pytorch/pytorch/pull/101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells.
Adding this test causes an error though (EDIT: this test is not relevant to this PR and instead just reveals that `cond` with Python side effects is still broken):
```python
def test_closure_out_of_scope_cell_with_cond(self):
from functorch.experimental.control_flow import cond
cell1 = torch.rand(3, 3)
cell2 = torch.rand(3, 3)
orig3 = torch.rand(3, 3)
def test(x):
cell3 = orig3.clone()
def then():
nonlocal cell3
cell3 += cell1
return cell3
def els():
nonlocal cell3
cell3 += cell2
return cell3
return cond(x > 0, then, els, [])
opt_fn = torch._dynamo.optimize("eager")(test)
result1 = opt_fn(1)
self.assertTrue(torch.allclose(result1, orig3 + cell1))
result2 = opt_fn(-1)
self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2))
```
```
Traceback (most recent call last):
File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond
result1 = opt_fn(1)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn
return fn(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn
return fn(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
return _compile(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile
out_code = transform_code_object(code, transform)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform
tracer.run()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run
super().run()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
and self.step()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
getattr(self, inst.opname)(inst)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper
return inner_fn(self, inst)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function
(false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch
ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph
output = f.call_function(tx, args, {})
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function
return tx.inline_user_function_return(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_
tracer.run()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
and self.step()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
getattr(self, inst.opname)(inst)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl
self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function
proxy = tx.output.create_proxy(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy
return self.current_tracer.create_proxy(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy
new_arg = self.lift_tracked_freevar_to_input(arg)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input
self.parent.lift_tracked_freevar_to_input(proxy)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input
assert (
AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer
from user code:
File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test
return cond(x > 0, then, els, [])
File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els
cell3 += cell2
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104222
Approved by: https://github.com/jansel