Tracing `torch.backends.cudnn.is_acceptable(Tensor) -> bool:` fails with:
```
...
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/functions.py", line 196, in call_function
return super(UserFunctionVariable, self).call_function(tx, args, kwargs)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/functions.py", line 67, in call_function
return tx.inline_user_function_return(
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 426, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 1698, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 1752, in inline_call_
tracer.run()
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 485, in run
and self.step()
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 455, in step
getattr(self, inst.opname)(inst)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 281, in wrapper
return inner_fn(self, inst)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 912, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 389, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/torch.py", line 431, in call_function
tensor_variable = wrap_fx_proxy(
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/builder.py", line 662, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/builder.py", line 820, in wrap_fx_proxy_cls
raise AssertionError(
AssertionError: torch.* op returned non-Tensor bool call_function <function is_acceptable at 0x7f00deefb790>
```
So instead, evaluate `is_acceptable()` and convert the result to a constant. The result of `is_acceptable(tensor) -> bool` depends on:
* dtype/device of the input tensor (this should already be guarded)
* properties of the build & whether cudnn is available
* some global state that gets initialized during the first call to `torch.backends.cudnn._init()` (this is NOT guarded in this PR)
Note: this fixes tts_angular with FSDP. This was an issue with FSDP because FSDP modules are interpreted as UnspecializedNNModules, and UnspecializedNNModules try to inline calls. In comparison, NNModules (e.g. when the tts_angular model is not wrapped in FSDP) do not inline calls and instead evaluate subsequent calls. In subsequent calls, cudnn.is_acceptable would be skipped by eval_frame.py:catch_errors because it is not in an allowlist.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90323
Approved by: https://github.com/jansel
The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph. This is problematic for a number of reasons:
* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well with tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at https://github.com/pytorch/pytorch/pull/90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecodes necessary to reapply any side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts of the models could rely on the side effects.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond https://github.com/pytorch/pytorch/pull/87020#issuecomment-1338842844 which doesn't really make any sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branch produce the same guards which... is not useful? Like, I don't think that actually is even necessary for correctness)
This PR replaces the old implementation with a new implementation based on graphstate checkpointing. The basic idea is to process a cond(), we checkpoint the state of our interpreter, run the true branch, rollback to our checkpoint, run the false branch, rollback to our checkpoint and then merge the changes from both of the checkpoints. I require the true/false branches to have exactly the same side effects, but union their guards.
Some of the details:
* Dynamo is too aggressive with tracking side effects when processing closures, c.f. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078 The basic problem is whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example. To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents` so we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and give a better description about what the divergence was. You can test this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0`, `name_1`, now they are named `cond_true_0` and `cond_false_0`
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculating execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what would work is overriding graph with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate as you have to make sure you reset the Graph to its original before you restore a checkpoint, as checkpoints don't actually save graph for efficiency, and just undo changes on the graph. This capability may usefully get refactored to OutputGraph but I didn't do it in this PR for simplicity.
There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting with the original implementation.
* Not a problem per se, but if an NN module is used by both the true/false branch, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it.) I hope the export pipeline can deal with this.
* List of tensor output for cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc, and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back, c.f. https://github.com/pytorch/torchdynamo/issues/1840
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90286
Approved by: https://github.com/voznesenskym
The current cond implementation is silently incorrect when
there are outstanding side effects, since the locally tracked
side effects are lost when the recursive export call is made.
At least we raise an assert now.
I'm working on a refactor of cond which should be able to sidestep
this problem. Maybe.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: [D41746973](https://our.internmc.facebook.com/intern/diff/D41746973)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90208
Approved by: https://github.com/voznesenskym
This is a group of bug fixes for [7k github models](https://github.com/pytorch/torchdynamo/issues/1884), it would fix 30+ model tests.
* Support ```tensor.type()```.
* Support ```tensor.get_device()```.
* Support ```torch.nn.functional._Reduction.get_enum```.
* Support ```torch._utils._get_device_index()```.
* Fallback ```tensor.data_ptr()```.
* ```FakeTensor``` always returns 0
* For no fake tensor propagation, we ```clone``` the input tensor, which makes no sense to track the original ```data_ptr```. And I don't think this is a very popular API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89486
Approved by: https://github.com/jansel
Make mutation faster to speed up tracing optimizers, helps with https://github.com/pytorch/torchdynamo/issues/1803
`replace_all` no longer iterates over the entire variable tracker data structure every time a mutation is performed
Each variable tracker internally keeps a set of contained mutable variable trackers, to provide a hint to `replace_all`. This is populated with a call to `apply` from `__post_init__` in the base `VariableTracker`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89170
Approved by: https://github.com/jansel
Fixes error from 7k github models: https://github.com/jansel/pytorch-jit-paritybench/blob/master/generated/test_arashwan_matrixnet.py
Error:
```
AssertionError: torch.* op returned non-Tensor bool call_function <function is_tensor at 0x7fca94d0faf0>
from user code:
File "/scratch/ybliang/work/repos/pytorch-jit-paritybench/generated/test_arashwan_matrixnet.py", line 749, in scatter
return scatter_map(inputs)
File "/scratch/ybliang/work/repos/pytorch-jit-paritybench/generated/test_arashwan_matrixnet.py", line 741, in scatter_map
assert not torch.is_tensor(obj), 'Tensors not supported in scatter.'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88704
Approved by: https://github.com/jansel
**Introduces symbolic shape guards into dynamo.**
In this PR, we take the existing fake tensor infra and plumbing in dynamo and we start passing a shape_env around. This shape_env does not get plumbed down to middle layers / backend yet - it only collects expressions from frontend invocations at the moment. We then translate these expressions into guards at the point where we take other guards installed throughout dynamo - and add them to check_fn.
Part 1 of https://docs.google.com/document/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit#
cc @jansel @lezcano @fdrocha @mlazos @soumith @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87570
Approved by: https://github.com/ezyang
Right now, example_value is doing two jobs:
- We use it to propagate metadata (e.g. return type, shapes, etc.)
throughout the graph
- We use it to satisfy queries for the actual value (e.g. torch.cond,
`assume_constant_result`)
This is further complicated by the fact that we have two modes, one
where `example_value` is a fake tensor, and one where it is a real
tensor (this is the `fake_tensor_propagation` config flag).
This leads to scenarios where we don't support every combination of
job + mode,
e.g. if `fake_tensor_propagation=False`, `assume_constant_result` is
broken.
This is made worse by the fact that "fake tensor mode" is the default
and is required if you want dynamic shapes to work.
So, this PR introduces a `get_real_value` API that just runs the graph
up to `node` in order to get a concrete value. This API is orthogonal
to
`example_value`, so it doesn't care about `fake_tensor_propagation`.
When `fake_tensor_propagation=True`: `example_value` is a fake tensor,
you must use the `get_real_value` API to get a concrete value. This
will
be the only configuration in the future.
When `fake_tensor_propagation=False`: `example_value` and
`get_real_value` will produce the same value. This is redundant but we
will be removing this config soon.
To support this, I introduce a cache for computed real values, to
memoize the work involved if we're asking for real values a lot.
I attached this state to `OutputGraph` because it seems to be what
historically managed `example_value` lifetimes, but idk.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87091
Approved by: https://github.com/wconstab