Summary: ONNX team and recent transformer upgrade ran into this error and we also ran into during our export benchmarking. This diff makes it possible to trace through vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR and in the future, we should desugar them to post-grad ops.
The implementation strategy is:
1. We add python wrappers around vmap APIs so that we attach custom torch function handler that is only on during non-strict export. The reason is we don't want to add this to default torch_function handler because it will break BC.
2. Some dynamo changes to make sure it picks up new python wrapper APIs. The reason is when we do strict export, we need to re-materialize these APIs in pre-dispatch IR from torch IR. We can avoid this by special casing in dynamo for export to proxy different API calls but i feel that is too much chaos because you need to be able to proxy 2 different variants of same vmap API.
Test Plan: CI
Differential Revision: D75623875
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154650
Approved by: https://github.com/ezyang, https://github.com/zou3519
Dynamo was aggressively specializing on lazy VTs over `set_name_hint` in
`STORE_FAST`, etc., and `isinstance` in `LOAD_FAST_CHECK`. This causes
regional `torch.compile` from optimizing ComfyUI GGUF + LoRA to either
(1). exceed the recompialtion limit of 8, which results in suboptimal
performance, and (2). even if recompilation limit is increased, the
compilation time gets unnecessarily high (180s v.s. 20s for Flux).
This patch fixes the recompilation issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156891
Approved by: https://github.com/williamwen42, https://github.com/mlazos
Usage:
```python
from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper
# Your ordinary function wrapper
def my_hop_fn_impl(fn, *args, k=1, **kwargs):
def wrapper(*args, **kwargs):
out = fn(*args, **kwargs)
if isinstance(out, tuple):
return (out[0] + k,)
return out + k
return wrapper
# Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph
def my_hop_fn(fn, *args, k=1, **kwargs):
return dynamo_bypassing_wrapper(
functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs
)
```
Notes:
- The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn.
- The `wrapper_fn` passed to `dynamo_bypassing_wrapper ` should have signature `Callable -> Callable`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153487
Approved by: https://github.com/ydwu4
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows
Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic
Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....
Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.
We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows
Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic
Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....
Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.
We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
Adds a `invoke_quant` higher order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0).
The primary motivations are
- Unifying scattered reasoning for quant operators throughout the code base
- Easy of pattern matching - see this very large pattern match expression [here](949fdd2997/torch/_inductor/fx_passes/post_grad.py (L390-L426). Compared to the pattern I have in the tests:
```
@register_graph_pattern(
CallFunction(
torch.ops.aten.mm,
CallFunction(
torch.ops.higher_order.invoke_quant,
Ignored(),
Ignored(),
Ignored(),
scheme="nf4",
),
Arg(),
),
pass_dict=test_pass,
)
```
- Ability to specify inductor specific logic, like codegen'ing the operators in lower precision, or forcing fusion to a matmul.
Example graph:
``` Python
===== AFTER POST GRAD =====
/data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
# File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self) # type: ignore[call-arg]
repeated_subgraph0 = self.repeated_subgraph0
invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4'); repeated_subgraph0 = arg0_1 = arg1_1 = None
return (invoke_quant,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
# File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self) # type: ignore[call-arg]
mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = None
add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1); mul = arg1_1 = None
return add
```
The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)` where the scheme will not always be present.
I wasn't sure exactly how the inductor specific configurations like `codgen_in_low_precision` should be passed through. I didnt want to stuff them all in as kwargs, and I didn't want to have them affect pattern matching. So they will be stored as meta of the node itself. And, following that, I wanted the invocation of the hop to match how it will show up in the graph. So I decided to have it be an object that is then invoked for the tracing.
```
invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4")
```
Todo - not require the packing of args in a tuple, will do following https://github.com/pytorch/pytorch/pull/139162.
Feedback welcome.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139102
Approved by: https://github.com/Chillee
This patch models input cell object as "newly created" rather than
"pre-existing" python object (see added documentation for why this
actually captures the semantics more accurately).
This enables the `SideEffects.prune_dead_object_new` algorithm to prune
away writes to input cell objects which are no longer relevant; this
didn't happen prior to this patch because we modelled them as
pre-existing objects, which forces us to codegen their attribute
mutations.
Fixes#145564.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145781
Approved by: https://github.com/williamwen42, https://github.com/jansel
This basically undoes some workarounds introduced in #119926, the
root causes of which have been fixed by #142078 and other changes in
Dynamo.
Now that Dynamo traces the spec comparison code, the test also needs update:
- removing the `_jvp_treespec_compare` calls in fx graph
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142081
Approved by: https://github.com/zou3519
ghstack dependencies: #142078, #142080
This basically undoes most of the workarounds introduced in #119405, the
root causes of which have been fixed by #142078 and other changes in
Dynamo.
Now that Dynamo traces the spec comparison code, the test also needs update:
1. renaming `o` to `pimals_out`
2. removing the `_vjp_treespec_compare` calls in fx graph
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142080
Approved by: https://github.com/zou3519
ghstack dependencies: #142078
In addition to `NewCellVariable`, Dynamo has 3 ways of modeling cell objects:
1. For cells captured and created by the root frame, represent them as
their contents in `root_tx.symbolic_locals`, which `LOAD_DEREF` and
`STORE_DEREF` update directly, without going through `SideEffects`.
2. `ClosureVariable`: this is created when cells from (1) are captured
by a newly created function Dynamo is about to inline. It's a handle
with a name that redirects `LOAD_DEREF` and `STORE_DEREF` back (1),
to make `root_tx.symbolic_locals` up-to-date.
3. For cells that are captured by both the root frame and some
pre-existing function Dynamo is about to inline, represent those
cells as contents, and do not allow writes to them.
Note that (2) and (3) are mainly to conform with (1) -- to make sure
Dynamo has a consistent modeling of cells for the same cell objects.
In this patch, we represent all of these cells as `NewCellVariable`. The
main new code paths introduced are:
- using `NewCellVariable` to model cell objects created by the root
frame (the cells are passed in as input to `InstructionTranslator`),
this is what allows us to get rid of all 3 legacy paths above.
- adding a new `AutoDerefLocalSource` to deal with the python-code
level (guards) and bytecode level (codegen) auto-dereferencing
behavior, when accessing pre-existing python cells. This also
involves a tiny update to guard manager generation.
- plumbing some extra info into `LocalSource` and `CellVariable` so that
we can still emit `LOAD_DEREF`, `STORE_DEREF`, `LOAD_CLOSURE` (instead
of `make_cell`, `cell_contents` attribute access, and `LOAD_FAST`),
which is important for readability, performance, and some
assumptions `bytecode_transformation.py` makes.
As a result, this patch removes a lot of the now-dead code paths and
TODOs. Notably, it significantly simplified the `prune_dead_locals`
function, which was duplicating a lot of the logic from
`prune_dead_object_new`; this conveniently closes#137123.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140153
Approved by: https://github.com/jansel
ghstack dependencies: #140330, #140152, #140436, #140435