Before the change in this PR, we have an error for the following code
```python
import torch
torch._dynamo.config.capture_scalar_outputs = True
class M(torch.nn.Module):
def forward(self, idx, x):
u0 = idx.item()
x0 = x.select(0, u0)
def fn():
return x0.sin()
return torch.cond(x0.sum() > 0, fn, fn)
m = M()
out = torch.compile(m, fullgraph=True)(torch.tensor(0, dtype=torch.int64), torch.randn(3, 3))
```
The error is caused when speculate fn, and tries to lift symbol of x0.storage_offset() but found the symbols doesn't have a source associated with it.
What really happens is that, when input tensor is a scalar tensor of int type and resides on CPU, we have a short cut that creates a norm symint when .item() is called see https://github.com/pytorch/pytorch/pull/126245.
However, previously, we only track the unbacked symint output of an operation because we believe all the backed symint must have a source associated with it and has already bee lifted as input at the top-level. Now this invariant no longer holds, so we end up an error saying the symbol doesn't have source (because only input and symbols derided from inputs have source and result of .item() doesn't have a source).
In this PR, we start to also track the normal symint with the proxy that created it (i.e. in this case the proxy .item()).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161198
Approved by: https://github.com/zou3519
This adds a new function `bypass_package` and `CompilePackage.bypass_current_entry()`. This allows us to safely bypass if there are models with unserializable or incompatible parts. When we encounter something incompatible, we'll raise a bypass and ignore that particular code in DynamoCodeEntry.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160902
Approved by: https://github.com/zhxchen17
Summary: ONNX team and recent transformer upgrade ran into this error and we also ran into during our export benchmarking. This diff makes it possible to trace through vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR and in the future, we should desugar them to post-grad ops.
The implementation strategy is:
1. We add python wrappers around vmap APIs so that we attach custom torch function handler that is only on during non-strict export. The reason is we don't want to add this to default torch_function handler because it will break BC.
2. Some dynamo changes to make sure it picks up new python wrapper APIs. The reason is when we do strict export, we need to re-materialize these APIs in pre-dispatch IR from torch IR. We can avoid this by special casing in dynamo for export to proxy different API calls but i feel that is too much chaos because you need to be able to proxy 2 different variants of same vmap API.
Test Plan: CI
Differential Revision: D75623875
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154650
Approved by: https://github.com/ezyang, https://github.com/zou3519
As title. This is a follow-up of the previous patch, with the goal of
supporting a new pattern that showed up in ComfyUI:
644b23ac0b/comfy/ops.py (L44)
Effectively, the semantics of calling a function decorated with a
context manager is:
```python
@ctx_manager(args)
def f(x):
...
f(x)
# ----->
with ctx_manager(args):
f.__wrapped__(x)
```
Yes, a fresh context manager instance per invokation, see CPython source code:
https://github.com/python/cpython/blob/3.12/Lib/contextlib.py#L119-L122
So Dynamo already
1. knows how to handle the `with ctx_manager(args)` syntax, and has
special handling for a few torch native context managers, like
`sdpa_kernel` in this patch.
2. can trace through a good chunk (at least the ones that matter in this
case) of contextlib.
This patch just let Dynamo trace a bit more into contextlib, and then
keep the torch-native special cases by moving their handling a bit down
the stack, so that no additional logic is introduced -- it's only
refactored.
This also allows us to get rid of some `_sdpa_kernel_variadic` special
handling, since now we will trace through its code, and it boils down to
`sdpa_kernel` anyways.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160703
Approved by: https://github.com/guilhermeleobas, https://github.com/mlazos
ghstack dependencies: #160684
This patch fixes 2 issues, illustrated by the test cases added:
1. using `sdpa_kernel(backends=..., set_priority=...)` due to an
internal assert that forgot to be updated after #147768.
2. forgetting to convert the `set_priority` VariableTracker back to a
python constant so that its value is properly used by `sdpa_kernel`,
also from #147768.
I ran into (1) because ComfyUI had a recent update that actually sues
this pattern
644b23ac0b/comfy/ops.py (L44),
and then noticed (2), and fixed it conveniently.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160684
Approved by: https://github.com/mlazos
Changes:
(1) Replace UserDefinedSetVariable by UserDefinedObjectVariable in all binop calls
Test plan:
(1) The three tests from CPython `test_collections.py` ensures that Dynamo can trace through a dunder method (e.g. __add__, __ixor__, etc) defined in a user defined class
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159865
Approved by: https://github.com/mlazos
ghstack dependencies: #159365, #159366, #159368, #159483, #159902, #159864
After the change, the error stacktrace is attached with user code stack and is suppressed into 1 (without the scrolling up mssage). For example:
```python
class Test(torch.nn.Module):
def forward(self, c, x):
def cond_fn(c, x):
return c > 0 and x.size(0) < 20
def body_fn(c, x):
return c - 1, x.sin()
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
```
Now gives the following error message:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1705, in test_while_loop_size_mismatch_tensor_expansion
self._run_test(
~~~~~~~~~~~~~~^
model=WhileLoopModels.SizeMismatchTensorExpansion(),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<2 lines>...
dynamic=dynamic,
^^^^^^^^^^^^^^^^
)
^
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1417, in _run_test
result = model(*inputs_with_counters)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1053, in forward
return torch._higher_order_ops.while_loop(cond_fn, body_fn, (c, x))
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 176, in while_loop
return torch.compile(
~~~~~~~~~~~~~~
_while_loop_op_wrapper, backend=backend, fullgraph=True
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 804, in compile_wrapper
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1595, in __call__
result = self._torchdynamo_orig_backend(
frame, cache_entry, self.hooks, frame_state, skip=1
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1353, in __call__
result = self._inner_convert(
frame, cache_entry, hooks, frame_state, skip=skip + 1
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 682, in __call__
result = _compile(
frame.f_code,
...<16 lines>...
convert_frame_box=self._box,
)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 1172, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_utils_internal.py", line 98, in wrapper_function
return function(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 858, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 897, in _compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1461, in transform_code_object
transformations(instructions, code_options)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 300, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 818, in transform
tracer.run()
~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3528, in run
super().run()
~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 91, in graph_break_as_hard_error
raise exc.with_traceback(sys.exc_info()[2]) from None
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 77, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 1287, in call_function
) = speculate_subgraph(
~~~~~~~~~~~~~~~~~~^
tx,
^^^
...<33 lines>...
supports_aliasing=self.supports_aliasing,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 877, in speculate_subgraph
raise ex
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 718, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
return super().call_function(tx, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
return tracer.inline_call_()
~~~~~~~~~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
self.run()
~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 852, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2240, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1200, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 212, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 580, in call_function
return super().call_function(tx, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 334, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1217, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3733, in inline_call
return tracer.inline_call_()
~~~~~~~~~~~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 3936, in inline_call_
self.run()
~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1372, in run
while self.step():
~~~~~~~~~^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1276, in step
self.dispatch_table[inst.opcode](self, inst)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 830, in inner
unimplemented_v2(
~~~~~~~~~~~~~~~~^
gb_type="Data-dependent branching",
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...<5 lines>...
],
^^
)
^
File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 580, in unimplemented_v2
raise Unsupported(msg)
torch._dynamo.exc.UncapturedHigherOrderOpError: while_loop doesn't work unless it is captured completely with torch.compile. Got Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0170.html
from user code:
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 167, in _while_loop_op_wrapper
return while_loop_op(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_higher_order_ops/while_loop.py", line 137, in flat_cond_fn
return cond_fn(*carried, *additional)
File "/home/yidi/local/pytorch/test/inductor/test_control_flow.py", line 1047, in cond_fn
return c > 0 and x.size(0) < 20
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
To execute this test, run the following from the base repo dir:
python test/inductor/test_control_flow.py WhileLoopTests.test_while_loop_size_mismatch_tensor_expansion_device_cpu_dynamic_False
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159296
Approved by: https://github.com/zou3519
**Summary**
This PR adds an all-gather based FlexAttention and uses TorchFunctionMode to dispatch
`FlexAttentionHOP.__call__` to it.
This PR makes the following changes:
- add a user-facing API `create_cp_block_mask` for creating CP-specific `BlockMask`
which masks over the attention result of Q shard and KV global.
- add `_ContextParallelGlobalVars` to store all necessary global vars that CP FlexAttention
requires. `torch_function_mode` is critical to maintain singleton mode to avoid dynamo
recompilations.
- add a dispatch path for `FlexAttentionForwardHOP.__call__` (TorchFunctionMode dispatch
won't work correctly without this line)
What's not in this PR:
- QKV load balancing
- Test on other masking besides `causal_mask`.
- Support on small attention (i.e. qkv size is smaller than 128) because the block mask
rewrite function requires `Q_BLOCK_SIZE == KV_BLOCK_SIZE == 128`.
**Test**
`pytest test/distributed/tensor/test_attention.py -s -k test_ring_flex_attention`
**Followup**
1. create an issue to reproduce the error in `create_fw_bw_graph()` when trying to call `create_block_mask`
to re-write `block_mask` in `FlexAttentionHOP` dispatch in `TorchFunctionMode`.
2. Merge `_ContextParallelGlobalVars` and `_cp_options`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158692
Approved by: https://github.com/drisspg