It's part of the effort to improve PT2 Export UX. This PR is to improve the usability of `torch.cond()` by separating user errors from the dynamo internal errors. By definition, user error means the usage of `torch.cond()` violates the restrictions of this API therefore needs users to take action and fix the error.
In this notebook N3363227 we discovered a bunch of limitations of using `torch.cond(pred, true_fn, false_fn, operands)`. In summary, the limitations can be categorized as:
- predicate restriction (`pred`)
- operands restriction (`operands`)
- branch restriction (`true_fn` & `false_fn`)
The error message will be more accurate about where the (user) error is from and more actionable for users to fix it.
For example, `operands` must be a list of tensors and the signature of `true_fn` and `false_fn` must match with the `operands`.
If the operands contains non-tensor types, user will see error message like:
```
torch._dynamo.exc.UserError: Expected a list of tensors but got ["<class 'torch.Tensor'>", "<class 'float'>"]
from user code:
File "~/pytorch/test/dynamo/test_export.py", line 2504, in f_non_tensor_operands
return cond(True, lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a])
```
If the signature of the branch function doesn't match with `operands`, user will see error message like:
```
torch._dynamo.exc.UserError: too many positional arguments.
func = 'false_fn' ~/pytorch/test/dynamo/test_export.py:2514, args = [<class 'torch.Tensor'>, <class 'torch.Tensor'>], kwargs = {}
```
Or if the tensor returned from user defined branches has different metadata, e.g. shapes, dtypes, etc., user will see error message like:
```
TypeError: Expected each tensor to have same metadata but got:
cond_true_0 returns TensorMetadata(shape=torch.Size([2, 1]), dtype=torch.int64, requires_grad=False, stride=(1, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
cond_false_0 returns TensorMetadata(shape=torch.Size([1]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={})
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98909
Approved by: https://github.com/jansel
It's part of the effort to improve PT2 Export UX. This PR is to improve the usability of `torch.cond()` by allowing user to set `pred` as `ConstantVariable` as it's not often to see control flow on rank or a tensor or dim size which is traced as `ConstantVariable`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98900
Approved by: https://github.com/jansel
Summary:
Replace _dynamo.config with an object instead of module
Current usage patterns of setting and reading fields on config will work
unchanged.
Only changes needed going forward:
1. import torch._dynamo.config will not work. However, just doing
import torch._dynamo is sufficient to access dynamo config
as torch._dynamo.config.
2. Files inside of _dynamo folder need to access config via
from torch._dynamo.config_util import config instead of
from torch._dynamo import config. Because _dynamo/__init__.py
imports some of the files so it would be circular import.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96455
Approved by: https://github.com/williamwen42
### Overview
This PR de-duplicates graph inputs in TorchDynamo, using the `Source` as the unique identifier for each input. This closes https://github.com/pytorch/pytorch/issues/98743 and https://github.com/pytorch/pytorch/issues/98625.
### Details
`VariableBuilder.wrap_tensor()` should return a `VariableTracker` for the passed-in `value: Tensor`. If `value` is duplicated, we should avoid calling `OutputGraph.create_graph_input()` and `OutputGraph.add_grapharg()`.
- Note that `create_graph_input()` and `add_grapharg()` are not 1:1. For a constant source and either `wrap_sym()` or `wrap_unspecialized_primitive()`, TorchDynamo still calls `create_graph_input()` but not `add_grapharg()`.
- Note that `create_graph_input()` should be called before constructing the corresponding `VariableTracker`. TorchDynamo needs the `fx.Proxy` object to pass to `wrap_fx_proxy()`.
In this PR, the `OutputGraph` saves an additional mapping `input_source_to_var` from each graph input's `Source` to its `VariableTracker`, which works because `Source` is now hashable. This mapping should be updated each time `create_graph_input()` is called. However, since we must construct the `VariableTracker` after `create_graph_input()` returns, we must have a separate call to the `OutputGraph` to update the mapping.
If anyone has any suggestion on how to coalesce this logic and avoid having to remember to update `input_source_to_var` for each `create_graph_input()`, I would love to hear it.
<details>
<summary> Alternate Approach</summary>
Initially, I tried having TorchDynamo construct a new but equivalent `VariableTracker` for the duplicated tensor. However, I abandoned this approach after hitting an assertion in `def wrap_fx_proxy_cls()` due to `"example_value"` already being in the proxy node's metadata because we were reusing the primary tensor's `Proxy` object. Reusing the exact `VariableTracker` also seems less error-prone instead of requiring constructing a new but identical `VariableTracker`.
</details>
### Testing
#### Global Variable Test
```
import torch
@torch.compile()
def f():
return x + x
x = torch.randn(3)
f()
```
Before:
```
====== Forward graph 0 ======
<eval_with_key>.6 class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[3], arg1_1: f32[3]):
# File: /data/users/ezyang/b/pytorch/ff.py:5, code: return x + x
add: f32[3] = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add,)
```
After (only `arg0_1` and no more `arg1_1`):
```
====== Forward graph 0 ======
<eval_with_key>.4 class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[3]):
# File: dynamo/test_dup_global.py:8, code: return x + x
add: f32[3] = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
return (add,)
```
#### FSDP Test
Before we error on
```
File "/.../pytorch/torch/_guards.py", line 244, in __post_init__
assert self.input_source_a != self.input_source_b
```
and now there is no error.
---
The rename from `name_to_input` to `input_name_to_proxy` is not part of the core logic change and is a remnant from initial attempts. I can undo it later if desired, but I also feel that the new name is more informative. It also fixes the type annotation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98775
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
Twice this week I have had people confuse "operator defined with Python
operator registration aka torch.library" and "PyOperator which is used
to define control flow operators and other operators that cannot be
represented in JIT schema." Renaming PyOperator for clarity.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97493
Approved by: https://github.com/SherlockNoMad
The purpose of this PR is to remove reliance on argument positions in dedup guards, AND extend the functionality to params.
A version of this PR was stamped prior https://github.com/pytorch/pytorch/pull/95831 - but was kinda gross, because it was based on an underlying PR that did way too much with source names.
This PR leaves most of that alone, in favor of just reusing the same name standardization logic that dynamo module registration does.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96774
Approved by: https://github.com/ezyang
OK, so this PR used to be about reducing the number of constants we specialize on, but it turns out that unspecialization was ~essentially never used (because we still constant specialized way too aggressively) and I ended up having to fix a bunch of issues to actually get tests to pass. So this PR is now "make int unspecialization actually work". As part of this, I have to turn off unspecialization by default, as there are still latent bugs in inductor.
The general strategy is that an unspecialized int is represented as a SymInt. Representing it as a 0d tensor (which is what the code used to do) is untenable: (1) we often need unspecialized ints to participate in size computations, but we have no way of propagating sympy expressions through tensor compute, and (2) a lot of APIs work when passed SymInt, but not when passed a Tensor. However, I continue to represent Numpy scalars as Tensors, as they are rarely used for size computation and they have an explicit dtype, so they are more accurately modeled as 0d tensors.
* I folded in the changes from https://github.com/pytorch/pytorch/pull/95099 as I cannot represent unspecialized ints as SymInts without also turning on dynamic shapes. This also eliminates the necessity for test_unspec.py, as toggling specialization without dynamic shapes doesn't do anything. As dynamic shapes defaults to unspecializing, I just deleted this entirely; for the specialization case, I rely on regular static shape tests to catch it. (Hypothetically, we could also rerun all the tests with dynamic shapes, but WITH int/float specialization, but this seems... not that useful? I mean, I guess export wants it, but I'd kind of like our Source heuristic to improve enough that export doesn't have to toggle this either.)
* Only 0/1 integers get specialized by default now
* A hodgepodge of fixes. I'll comment on the PR about them.
Fixes https://github.com/pytorch/pytorch/issues/95469
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95621
Approved by: https://github.com/jansel, https://github.com/Chillee
Applies the remaining flake8-comprehension fixes and checks. This changes replace all remaining unnecessary generator expressions with list/dict/set comprehensions which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as 'set(a for a in b)`, resolving it into just the set call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
As found in #92709, thanks to @ngimel and @jansel, currently `torch.Tensor.fn` points to `UserDefinedObjectVariable` rather than `TorchVariable`. The root cause is due to https://github.com/pytorch/pytorch/pull/92709#pullrequestreview-1273357406. To prevent this, build `TorchVariable` of `torch.Tensor.fn` pointing to `torch.ops.aten.fn`.
This issue propagates to `torch.Tensor.fn` causing graph break with `nopython=True`.
```python
import torch
import torch._dynamo as dynamo
#op = torch.ops.aten.abs_ # no graph break
op = torch.Tensor.abs_ # graph break
args = torch.empty(10)
def foo(args):
return op(args)
opt_foo = dynamo.optimize("inductor", nopython=True)(foo)
y_ = opt_foo(args)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93243
Approved by: https://github.com/jansel
# Summary
This PR creates _flash_attention_backward and _scaled_dot_product_flash_attention_backward native functions and registers them to the respective derivatives.yaml.
The goal is to replicate the torch.autograd.Function defined in the FlashAttention repo [here](33e0860c9c/flash_attn/flash_attn_interface.py (L126)) natively in PyTorch. One thing that we don't have access to is ctx.save_for_backward in native PyTorch so in order to save these variables I extended the returned objects from the forward functions.
### MetaFunctions
I also updated the FlashAttention meta functions to mirror the real outputs now. As well I added a meta registration for backwards. I have an XLMR training script and while eager training now works with FlashAttention compiling this module fails with the inductor error down below.
### Questions?
Performance issues vs mem efficient when using torch.nn.mha_forward
TorchCompile -> See purposed solution below.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92917
Approved by: https://github.com/cpuhrsch
Tracing `torch.backends.cudnn.is_acceptable(Tensor) -> bool:` fails with:
```
...
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/functions.py", line 196, in call_function
return super(UserFunctionVariable, self).call_function(tx, args, kwargs)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/functions.py", line 67, in call_function
return tx.inline_user_function_return(
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 426, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 1698, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 1752, in inline_call_
tracer.run()
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 485, in run
and self.step()
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 455, in step
getattr(self, inst.opname)(inst)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 281, in wrapper
return inner_fn(self, inst)
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 912, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/symbolic_convert.py", line 389, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/torch.py", line 431, in call_function
tensor_variable = wrap_fx_proxy(
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/builder.py", line 662, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/scratch/dberard/dynamo38/pytorch/torch/_dynamo/variables/builder.py", line 820, in wrap_fx_proxy_cls
raise AssertionError(
AssertionError: torch.* op returned non-Tensor bool call_function <function is_acceptable at 0x7f00deefb790>
```
So instead, evaluate `is_acceptable()` and convert the result to a constant. The result of `is_acceptable(tensor) -> bool` depends on:
* dtype/device of the input tensor (this should already be guarded)
* properties of the build & whether cudnn is available
* some global state that gets initialized during the first call to `torch.backends.cudnn._init()` (this is NOT guarded in this PR)
Note: this fixes tts_angular with FSDP. This was an issue with FSDP because FSDP modules are interpreted as UnspecializedNNModules, and UnspecializedNNModules try to inline calls. In comparison, NNModules (e.g. when the tts_angular model is not wrapped in FSDP) do not inline calls and instead evaluate subsequent calls. In subsequent calls, cudnn.is_acceptable would be skipped by eval_frame.py:catch_errors because it is not in an allowlist.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90323
Approved by: https://github.com/jansel
The original implementation of cond() operator support in dynamo operated by recursively calling export() on the inner subgraph. This is problematic for a number of reasons:
* My original motivating reason: the original implementation had to play tricks to feed real tensors to the recursive export call, which means that it doesn't work well with tracing with dynamic shapes (where we MUST stay in fake tensors to accurately track dynamic shapes across the cond invocation)
* If there are pending side effects, the recursive export() call won't see those side effects (as they are only tracked by Dynamo, not actually applied to the Python environment.) You can see an example where dynamo cond tracing does the wrong thing at https://github.com/pytorch/pytorch/pull/90208
* If there were side effects inside the true/false branch, these side effects were silently lost (as the export only returns the graph of tensor operations, and not any of the residual Python bytecodes necessary to reapply any side effects.) This could have substantive effects on the export of subsequent parts of the model, as those parts of the models could rely on the side effects.
* It was not possible to track NN module accesses inside the true/false branches, necessitating a hack where the NN module was explicitly passed in as an input to cond https://github.com/pytorch/pytorch/pull/87020#issuecomment-1338842844 which doesn't really make any sense from a backend compilation perspective
* Guards induced from the inside of the true/false branch were not properly propagated to the top level guards; they were just silently dropped (in fact, the original implementation checked that the true/false branch produce the same guards which... is not useful? Like, I don't think that actually is even necessary for correctness)
This PR replaces the old implementation with a new implementation based on graphstate checkpointing. The basic idea is to process a cond(), we checkpoint the state of our interpreter, run the true branch, rollback to our checkpoint, run the false branch, rollback to our checkpoint and then merge the changes from both of the checkpoints. I require the true/false branches to have exactly the same side effects, but union their guards.
Some of the details:
* Dynamo is too aggressive with tracking side effects when processing closures, c.f. https://github.com/pytorch/torchdynamo/pull/233/files#r1040480078 The basic problem is whenever I define a closure, this immediately counts as a side effect, even if I didn't actually mutate anything. This triggered on the nested cond export example. To prevent this from happening, I optimistically avoid tracking side effects, but if a STORE_DEREF happens, I restart analysis with the relevant Source.name() added to `mutated_closure_cell_contents` so we start tracking on closure allocation. This is enough to fix the relevant test.
* For the most part, I assert that the graph states must be equivalent after applying the true/false branches. During debugging, I found it useful to be able to compare two graph states and give a better description about what the divergence was. You can test this using the `diff()` method I've added to a few structures.
* The implementation now supports NestedUserFunctionVariable, which is nice as it allows the true/false branches to be defined closer to the cond implementation.
* I fixed the naming of the true/false subgraphs; previously they were named `name_0`, `name_1`, now they are named `cond_true_0` and `cond_false_0`
* I added `name_to_input` to the saved graph state. I don't actually know if this is necessary, but it seemed like a good idea.
* I have to play some tricks to get the speculating execution of the true/false branch to record into a subgraph. After a careful read of OutputGraph, I found that what would work is overriding graph with a fresh Graph that we want to write things into, and manually setting up the inputs/outputs. It's a little delicate as you have to make sure you reset the Graph to its original before you restore a checkpoint, as checkpoints don't actually save graph for efficiency, and just undo changes on the graph. This capability may usefully get refactored to OutputGraph but I didn't do it in this PR for simplicity.
There are some further problems with the cond() implementation that I leave for future work. Most of these were preexisting with the original implementation.
* Not a problem per se, but if an NN module is used by both the true/false branch, it will show up in the final graph twice (since it has to be a submodule of the GraphModule that makes use of it.) I hope the export pipeline can deal with this.
* List of tensor output for cond is not supported.
* The true/false return values may not have consistent sizes/dims/etc, and we don't check them for consistency.
* If we modify fake tensors in the true/false branches, we aren't rolling them back, c.f. https://github.com/pytorch/torchdynamo/issues/1840
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90286
Approved by: https://github.com/voznesenskym
The current cond implementation is silently incorrect when
there are outstanding side effects, since the locally tracked
side effects are lost when the recursive export call is made.
At least we raise an assert now.
I'm working on a refactor of cond which should be able to sidestep
this problem. Maybe.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: [D41746973](https://our.internmc.facebook.com/intern/diff/D41746973)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90208
Approved by: https://github.com/voznesenskym
This is a group of bug fixes for [7k github models](https://github.com/pytorch/torchdynamo/issues/1884), it would fix 30+ model tests.
* Support ```tensor.type()```.
* Support ```tensor.get_device()```.
* Support ```torch.nn.functional._Reduction.get_enum```.
* Support ```torch._utils._get_device_index()```.
* Fallback ```tensor.data_ptr()```.
* ```FakeTensor``` always returns 0
* For no fake tensor propagation, we ```clone``` the input tensor, which makes no sense to track the original ```data_ptr```. And I don't think this is a very popular API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89486
Approved by: https://github.com/jansel