We want to get to a point where most UserErrors link to exportdb examples. This PR makes passing case names non-optional to make this intent clearer and encourage developers who raise UserErrors to make or point to examples that make fixing such errors more obvious for users.
In addition, sometimes there are multiple examples that are relevant to an error. Thus this PR also enables passing multiple case names.
Retry of #110733 which was reverted due to a landrace.
Differential Revision: [D50087148](https://our.internmc.facebook.com/intern/diff/D50087148/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110878
Approved by: https://github.com/gmagogsfm, https://github.com/tugsbayasgalan
We want to get to a point where most `UserError`s link to `exportdb` examples. This PR makes passing case names non-optional to make this intent clearer and encourage developers who raise `UserError`s to make or point to examples that make fixing such errors more obvious for users.
In addition, sometimes there are multiple examples that are relevant to an error. Thus this PR also enables passing multiple case names.
Differential Revision: [D50020465](https://our.internmc.facebook.com/intern/diff/D50020465/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110733
Approved by: https://github.com/zhxchen17
Ideally all `_dynamo.exc.UserError`s should have "case names", i.e., link to examples in `exportdb`.
This PR adds case names to several instances of `_dynamo.exc.UserError`. In particular, looking at coverage based on `UserErrorType`:
* `DYNAMIC_CONTROL_FLOW`, `ANTI_PATTERN`, and `STANDARD_LIBRARY` are fully covered.
* `CONSTRAINT_VIOLATION` and `DYNAMIC_DIM` have no coverage. We don't seem to have any dedicated examples of specifying dynamic shapes in `exportdb` (although they are used in some other examples without explanation, to avoid some specialization that would make such examples moot).
* `INVALID_INPUT` is only partly covered. Frankly this is tedious to cover via examples.
Differential Revision: [D49928518](https://our.internmc.facebook.com/intern/diff/D49928518/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110555
Approved by: https://github.com/angelayi, https://github.com/ydwu4
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.
This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.
The added unit test demonstrates the added functionality of this PR.
~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~
~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~
Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
before the PR, for HF's ModelOutput class, we use dicts.py/DataClassVariable with our own implementation on __getItem__, __setAttr__, __setItem__. There is a risk that ModelOutput logic may change since it is a user code
after the PR, we inline __getItem__, __setAttr__, __setItem__ using dicts.py/CustomizedDictVariable so the logic always keep AA
unit test
* python test/dynamo/test_model_output.py -k test_HF_bert_model_output
test on HF benchmark
* python benchmarks/dynamo/huggingface.py -d cuda --inference --accuracy --progress --inductor --print-dataframe-summary 2>&1
* all metric are the same before/after the PR, including pass rate, unique_graphs, graph_breaks, unique_graph_breaks
* before the PR: P790393916
* after the PR: P790368991
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105044
Approved by: https://github.com/jansel
The strategy in this PR is pretty straightforward.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
The plan:
**For tensors w/ a source:**
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2 modified bytecode with the original eager code, we call `register_hook`. This registration of hooks in residuals is sound because (a) it happens right after a Pt2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the users invoking frame. This means we can soundly know it will be around to invoke `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.
**For tensors w/o a source:**
Graph break - we will support this in a subsequent PR
**Handles:**
An interesting new component here is the creation of a `STORE_FAST `->`LOAD_FAST` associated with the handle, the return result of `register_hook`. If the user code stored the result of `register_hook` in a handle, we need to honor that. We do so by interceding into `STORE_FAST`, and recording the name of the local variable as directed by user code. We then honor that same name in the reconstructed bytecode. If the user did not store a hook, we merely pop the produced value to preserve the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108903
Approved by: https://github.com/ezyang
ghstack dependencies: #108846, #109092
Summary:
Original commit changeset: e11cddf1fecc
Original Phabricator Diff: D49064185
Test Plan:
Comparing PT1 and PT2 performance on the IG Feed Model with this diff backed out: N4274204
Comparing the PT1 and PT2 performance on IG Feed with this diff committed: N4271093
Reviewed By: zou3519
Differential Revision: D49230047
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109199
Approved by: https://github.com/zou3519, https://github.com/xw285cornell
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.
This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.
The added unit test demonstrates the added functionality of this PR.
~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~
~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
Fixes#106893
There are two main changes:
- Before this PR, the function returned by once_differentiable was
included in skipfiles (because its .co_code is
torch/autograd/function.py). This PR adds a mechanism to tell Dynamo
to inline a function, no matter if it is included in skipfiles.
- A bugfix: when we are introspecting the backward, we need to turn the
grad mode off. This is to accurately model the eager-mode semantics:
In eager-mode PyTorch, if second-order gradients were not requested, then
the grad mode is off. torch.compile does not work with higher-order
gradients and just assumes we do first-order gradients, so this is OK.
Test Plan:
- new test
Differential Revision: [D49064185](https://our.internmc.facebook.com/intern/diff/D49064185)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108686
Approved by: https://github.com/voznesenskym
**This PR is a 99% copy paste of Sam Gross** (@colesbury) work at https://github.com/pytorch/pytorch/pull/100642. Copied from there
--------
The NN_MODULE guard now subsumes guards on Module attributes. The check_fn will fail if the module attributes are changed (such as Module.training), parameters, submodules, and buffers are added or removed, and if fields are changed on the type itself.
This gives up specificity in the guard check -- if any field is changed the check_fn fails -- for faster overall checks.
-----
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108528
Approved by: https://github.com/ezyang
Marks all params/optimizer state as static addresses and a finalizer which cleans up the graph attributes when the optimizer goes out of scope.
**Note: this does not mark grads as static because this will increase memory usage significantly
There are two cases:
1. The upstream graph is cudagraphed - this case will work fine OOTB
2. The upstream graph is not cudagraphed - in this case, there will be a lot of copies introduced from the upstream (to copy the grads) into cudagraphed-owned memory, unless the user explicitly marks the grads as static. If the user does this, this will also require not deallocating the grads in zero_grad() (either the mod or optimizer version) by setting them to zero vs None. There is a PR (https://github.com/pytorch/pytorch/pull/107853) in flight to throw an error if zero_grad attempts to set static grads to None.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107504
Approved by: https://github.com/eellison
We want cond to always throw errors despite user's torch.compile mode.
The current implementation is to
1. catch the UserError.GRAPH_BREAK_IN_CONTROL_FLOW and once saw it, we directly raise: once in [break_graph_if_unsupported](bad3f2db40/torch/_dynamo/symbolic_convert.py (L1250)), which catches and raises for call_function (entry point of higher order operator) and a few others.
2. The raised exception is caught and raised again in [step](bad3f2db40/torch/_dynamo/symbolic_convert.py (L691)), where all instructions' exceptions are handled.
3. At the top-level, we treat it like an hard error and not supressing the errors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108027
Approved by: https://github.com/zou3519
ghstack dependencies: #108025, #108026
The new guard printout looks like this:
```
[DEBUG] GUARDS:
[DEBUG] ___check_type_id(L['name'], 7605632) # if name == "special_attr": # test/dynamo/test_misc.py:1155 in __getattribute__
[DEBUG] L['name'] == '_backward_pre_hooks' # if name == "special_attr": # test/dynamo/test_misc.py:1155 in __getattribute__
[DEBUG] ___check_obj_id(L['self'], 139746432564960) # return super().__getattribute__(name) # test/dynamo/test_misc.py:1157 in __getattribute__
[DEBUG] ___check_obj_id(L['__class__'], 1451499216) # return super().__getattribute__(name) # test/dynamo/test_misc.py:1157 in __getattribute__
[DEBUG] ___is_grad_enabled() # _dynamo/output_graph.py:346 in init_ambient_guards
[DEBUG] not ___are_deterministic_algorithms_enabled() # _dynamo/output_graph.py:342 in init_ambient_guards
[DEBUG] ___is_torch_function_enabled() # _dynamo/output_graph.py:350 in init_ambient_guards
[DEBUG] utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:348 in init_ambient_guards
```
Along with the guards, we also print what line of user code caused the guard to be added, or what line of Dynamo internal code added the guard (if there is no user stack trace, which is typically the case for ambient guards.)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107505
Approved by: https://github.com/mlazos, https://github.com/voznesenskym, https://github.com/anijain2305
Previously, you would get an error like
```
Dynamo input and output is a strict subset of traced input/output
```
now you get
```
Cannot export model which references tensors that are neither
buffers/parameters/constants nor are direct inputs. For each tensor, if you'd
like this tensor to be an explicit input, add it as a dummy argument
to the top-level model definition you are exporting; if you would
like its value to be embedded as an exported constant, wrap its access
in a function marked with @assume_constant_result.
G['bulbous_bouffant'], accessed at:
File "test_export.py", line N, in f
return bulbous_bouffant + y
```
This doesn't handle outputs, I'm going to hit that next.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106403
Approved by: https://github.com/tugsbayasgalan
Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.
Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond
def h(x):
return x * 5
def true_fn(x):
return x * 2
def false_fn(x):
return x * 3
def f(pred, x):
x = h(
h(h(x))
)
x = x[1:][:2]
torch._dynamo.graph_break()
x = cond(pred, true_fn, false_fn, [x])
opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```
Output:
```
$ TORCH_LOGS="trace_call" python playground9.py
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^
TRACE FX call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^^^^
TRACE FX call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
x = h(
~^
h(h(x))
^^^^^^^
)
^
TRACE FX call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE FX call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~^^^^
TRACE FX call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
return x * 2
~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
return x * 3
~~^~~
TRACE FX call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104676
Approved by: https://github.com/ezyang
Prototype for the feature request:
>When working on a codebase that is unfamiliar to you, it can be helpful to single step through all of the code to see what is getting executed, what conditional branches are taken, and where indirect function jumps go. Model x-ray uses dynamo to give you a single step log of every source code line that does something relevant (i.e., a Tensor operation)
Dynamo logs to the ~`starts_line`~ `trace_source` logging artifact at the start of tracing new bytecode with a new line. It logs the line of source code associated with that bytecode.
~~Dynamo logs to the `graph_source` logging when a FX GraphModule is constructed. For each node in the graph, it logs the location of the original source code associated with that node.~~
Development notes: https://docs.google.com/document/d/1LjFeHzCgDDt535QUq5HydcQs56d7jWl5RvW8TLZN19g/edit?usp=sharing
Since the draft, we removed the `graph_source` logging artifact since printing the code of `GraphModule`s already displays the original source.
Sample:
```python
import torch
from functorch.experimental.control_flow import cond
def true_fn(x):
return x * 2
def false_fn(x):
return x * 3
def f_cond(pred, x):
return cond(pred, true_fn, false_fn, [x])
def f_outer(pred, x):
y = f_cond(pred, x)
if x.sum() > 0:
x = x * 2
else:
x = x * 3
return x, y
opt_f_cond = torch.compile(f_outer, backend="eager")
opt_f_cond(torch.tensor(True), torch.randn(3, 3))
```
Logs:
```shell
$ TORCH_LOGS="trace_source" python playground8.py
TRACE starts_line f_outer playground8.py:54
def f_outer(pred, x):
TRACE starts_line f_outer playground8.py:55
y = f_cond(pred, x)
TRACE starts_line f_cond playground8.py:51 (inline depth: 1)
def f_cond(pred, x):
TRACE starts_line f_cond playground8.py:52 (inline depth: 1)
return cond(pred, true_fn, false_fn, [x])
TRACE starts_line true_fn playground8.py:45 (inline depth: 2)
def true_fn(x):
TRACE starts_line true_fn playground8.py:46 (inline depth: 2)
return x * 2
TRACE starts_line false_fn playground8.py:48 (inline depth: 2)
def false_fn(x):
TRACE starts_line false_fn playground8.py:49 (inline depth: 2)
return x * 3
TRACE starts_line f_outer playground8.py:56
if x.sum() > 0:
TRACE starts_line <resume in f_outer> playground8.py:56
if x.sum() > 0:
TRACE starts_line <resume in f_outer> playground8.py:57
x = x * 2
TRACE starts_line <resume in f_outer> playground8.py:60
return x, y
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104013
Approved by: https://github.com/ezyang
Fix https://github.com/pytorch/pytorch/issues/99639 by handling the case in `InliningInstructionTranslator`'s `LOAD_CLOSURE` definition when the requested cell is not in `self.closure_cells`.
My intuition is that the behavior of `LOAD_DEREF` and `STORE_DEREF` on a cell/freevar should not depend on whether or not we called `LOAD_CLOSURE` (that is, we shouldn't create a new cell var in `LOAD_CLOSURE` like in https://github.com/pytorch/pytorch/pull/101357). But we need a way to push cells created by the inlined function that were not present in the caller - `InlinedClosureVariable` is used to differentiate these cells from other cells.
Adding this test causes an error though (EDIT: this test is not relevant to this PR and instead just reveals that `cond` with Python side effects is still broken):
```python
def test_closure_out_of_scope_cell_with_cond(self):
from functorch.experimental.control_flow import cond
cell1 = torch.rand(3, 3)
cell2 = torch.rand(3, 3)
orig3 = torch.rand(3, 3)
def test(x):
cell3 = orig3.clone()
def then():
nonlocal cell3
cell3 += cell1
return cell3
def els():
nonlocal cell3
cell3 += cell2
return cell3
return cond(x > 0, then, els, [])
opt_fn = torch._dynamo.optimize("eager")(test)
result1 = opt_fn(1)
self.assertTrue(torch.allclose(result1, orig3 + cell1))
result2 = opt_fn(-1)
self.assertTrue(torch.allclose(result1, orig3 + cell1 + cell2))
```
```
Traceback (most recent call last):
File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1768, in test_closure_out_of_scope_cell_with_cond
result1 = opt_fn(1)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 295, in _fn
return fn(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/eval_frame.py", line 448, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 526, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 127, in _fn
return fn(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 360, in _convert_frame_assert
return _compile(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/utils.py", line 180, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 430, in _compile
out_code = transform_code_object(code, transform)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/convert_frame.py", line 415, in transform
tracer.run()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2029, in run
super().run()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
and self.step()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
getattr(self, inst.opname)(inst)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 391, in wrapper
return inner_fn(self, inst)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 1100, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 559, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1061, in call_function
(false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 1044, in speculate_branch
ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/torch.py", line 850, in speculate_subgraph
output = f.call_function(tx, args, {})
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/functions.py", line 121, in call_function
return tx.inline_user_function_return(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 595, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2134, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 2231, in inline_call_
tracer.run()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 708, in run
and self.step()
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 668, in step
getattr(self, inst.opname)(inst)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/symbolic_convert.py", line 162, in impl
self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/variables/builtin.py", line 497, in call_function
proxy = tx.output.create_proxy(
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 345, in create_proxy
return self.current_tracer.create_proxy(*args, **kwargs)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1109, in create_proxy
new_arg = self.lift_tracked_freevar_to_input(arg)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1226, in lift_tracked_freevar_to_input
self.parent.lift_tracked_freevar_to_input(proxy)
File "/scratch/williamwen/work/pytorch2/torch/_dynamo/output_graph.py", line 1219, in lift_tracked_freevar_to_input
assert (
AssertionError: lift_tracked_freevar_to_input on root SubgraphTracer
from user code:
File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1766, in test
return cond(x > 0, then, els, [])
File "/scratch/williamwen/work/pytorch2/test/dynamo/test_misc.py", line 1764, in els
cell3 += cell2
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104222
Approved by: https://github.com/jansel
Fixes#103613.
A requirement for HigherOrderOperators is that after Dynamo capture, the body
function should be functional (i.e. has no observable side effects).
If the body function mutates a variable that is not local to the body, then we
that should induce a graph break.
This PR distinguish between MutableLocals created inside/outside body
and adds relevant checks. (Design originally proposed by voznesenskym.)
- We tag each mutable_local with an id that corresponds to where it came
from. The mutable_local may represent an existing object that gets
tracked by Dynamo or an object that is created while Dynamo is
introspecting.
- This id changes when we are introspecting the body of a HigherOrderOperator.
- If Dynamo wants to perform a side effect using a mutable_local, we
check its .scope field with the current scope id and raise Unsupported
in the desired case (non-local mutation inside HigherOrderOperator body)
- The id is a global thread_local variable. I can make this not a global
variable, but it just takes some engineering time to thread a number through
each of the various ways Dynamo can construct a mutable_local.
Test Plan:
- Add a bunch of new tests. Tests combinations of {global, nonlocal} x
{number, Tensor, list, object, nn.Module} and asserts that HigherOrderOp
falls back on those cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104077
Approved by: https://github.com/voznesenskym, https://github.com/jansel
Added two signpost_event calls to torch.fx.experimental.symbolic_shapes, one for produce_guards (where we can give stats like how many free symbols and how many guards produced) and the other is for evaluate_expr after freeze (so we can look for cases where we're improperly discarding guards in backwards.)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103882
Approved by: https://github.com/Skylion007