Commit Graph

249 Commits

Author SHA1 Message Date
Yuanyuan Chen
fc8ac1216c [4/N] Remove unused loop variables in tests (#166690)
This PR removes unused loop variables in tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166690
Approved by: https://github.com/justinchuby, https://github.com/mlazos
2025-10-31 10:20:48 +00:00
Sherlock Huang
34d6ef7022 Update gm.print_readable to include Annotation (#165397)
Sample output
```
[rank0]:        # Annotation: {'compile_with_inductor': 'flex_attention'} File: /data/users/bahuang/pytorch/torch/nn/attention/flex_attention.py:1490 in flex_attention, code: out, lse, max_scores = flex_attention_hop(
[rank0]:        score_mod_2 = self.score_mod_2
[rank0]:        mask_fn_2 = self.mask_fn_2
[rank0]:        flex_attention_1 = torch.ops.higher_order.flex_attention(xq_5, xk_5, xv_3, score_mod_2, (2048, 2048, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_indices, 128, 128, mask_fn_2), 0.25, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___mask_mod___closure___0_cell_contents,));  xq_5 = xk_5 = xv_3 = score_mod_2 = mask_fn_2 = None
[rank0]:        out_2: "bf16[8, 4, 2048, 16]" = flex_attention_1[0];  flex_attention_1 = None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165397
Approved by: https://github.com/yushangdi, https://github.com/anijain2305, https://github.com/mlazos
2025-10-28 13:54:38 +00:00
PyTorch MergeBot
e50dc40d28 Revert "Update gm.print_readable to include Annotation (#165397)"
This reverts commit 7a65770013.

Reverted https://github.com/pytorch/pytorch/pull/165397 on behalf of https://github.com/malfet due to I don't know how/why, but it breaks windows tests, see 2e22b1a61e/1 ([comment](https://github.com/pytorch/pytorch/pull/165397#issuecomment-3417428128))
2025-10-17 22:35:50 +00:00
Sherlock Huang
7a65770013 Update gm.print_readable to include Annotation (#165397)
Sample output
```
[rank0]:        # Annotation: {'compile_with_inductor': 'flex_attention'} File: /data/users/bahuang/pytorch/torch/nn/attention/flex_attention.py:1490 in flex_attention, code: out, lse, max_scores = flex_attention_hop(
[rank0]:        score_mod_2 = self.score_mod_2
[rank0]:        mask_fn_2 = self.mask_fn_2
[rank0]:        flex_attention_1 = torch.ops.higher_order.flex_attention(xq_5, xk_5, xv_3, score_mod_2, (2048, 2048, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_indices, 128, 128, mask_fn_2), 0.25, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___mask_mod___closure___0_cell_contents,));  xq_5 = xk_5 = xv_3 = score_mod_2 = mask_fn_2 = None
[rank0]:        out_2: "bf16[8, 4, 2048, 16]" = flex_attention_1[0];  flex_attention_1 = None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165397
Approved by: https://github.com/yushangdi, https://github.com/anijain2305
2025-10-17 18:35:18 +00:00
Laith Sakka
b377c9e365 graph break on tolist if capture_scalar_outputs is false (#163807)
address https://github.com/pytorch/pytorch/issues/163798

its problematic to not graph break because:
1. break current contract.
2. well dynamo trace then we have .item call then if we ever re-trace later in autograd for example we hit a
 failure (We do not know where to graph break at that point)! see the added unit test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163807
Approved by: https://github.com/bobrenjc93
2025-09-28 04:02:52 +00:00
Simon Fan
821458d97a [dynamo][hop] Introduce Local Map HOP (#161458)
Can't actually deploy it because of: https://github.com/pytorch/pytorch/issues/161456

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161458
Approved by: https://github.com/ydwu4
2025-09-17 09:32:38 +00:00
PyTorch MergeBot
e7c3f802ff Revert "[dynamo][hop] Introduce Local Map HOP (#161458)"
This reverts commit 505458db80.

Reverted https://github.com/pytorch/pytorch/pull/161458 on behalf of https://github.com/jeffdaily due to broke rocm tests ([comment](https://github.com/pytorch/pytorch/pull/161458#issuecomment-3299230458))
2025-09-16 15:14:36 +00:00
Simon Fan
505458db80 [dynamo][hop] Introduce Local Map HOP (#161458)
Can't actually deploy it because of: https://github.com/pytorch/pytorch/issues/161456

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161458
Approved by: https://github.com/ydwu4
2025-09-16 00:37:40 +00:00
Animesh Jain
6b1900c22f [dynamo][hops] Remove const outputs from the speculated subgraph (#161355)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161355
Approved by: https://github.com/zou3519
2025-09-04 18:52:01 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
dbef606631 Add support for tracing vmap in pre-dispatch export (#154650)
Summary: ONNX team and recent transformer upgrade ran into this error and we also ran into during our export benchmarking. This diff makes it possible to trace through vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR and in the future, we should desugar them to post-grad ops.

The implementation strategy is:
1. We add python wrappers around vmap APIs so that we attach custom torch function handler that is only on during non-strict export. The reason is we don't want to add this to default torch_function handler because it will break BC.
2. Some dynamo changes to make sure it picks up new python wrapper APIs. The reason is when we do strict export, we need to re-materialize these APIs in pre-dispatch IR from torch IR. We can avoid this by special casing in dynamo for export to proxy different API calls but i feel that is too much chaos because you need to be able to proxy 2 different variants of same vmap API.

Test Plan: CI

Differential Revision: D75623875

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154650
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-08-20 19:31:07 +00:00
ghostspiders
af10f1f86c Fix requires_cuda to requires_cuda_and_triton (#160222)
Fixes ##159399

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160222
Approved by: https://github.com/janeyx99
2025-08-10 07:05:52 +00:00
gaoyvfeng
50f23ff6f8 rename-HAS_CUDA-to-HAS_CUDA_AND_TRITON (#159883)
Fixes #159399
"Modified torch.testing._internal.inductor_utils and test/inductor"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159883
Approved by: https://github.com/janeyx99
2025-08-08 15:44:52 +00:00
Yidi Wu
da05b7fb94 [cond] add _FlopCounterMode support for cond (#158067)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158067
Approved by: https://github.com/zou3519
ghstack dependencies: #158077
2025-07-16 17:26:20 +00:00
Yidi Wu
82b1c48292 [hop] add supports_higher_order_operators flag to TorchDispatchMode (#158077)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158077
Approved by: https://github.com/zou3519
2025-07-16 17:26:20 +00:00
Ryan Guo
f742b32a2f [dynamo] Avoid recompiling over unused objects (#156891)
Dynamo was aggressively specializing on lazy VTs over `set_name_hint` in
`STORE_FAST`, etc., and `isinstance` in `LOAD_FAST_CHECK`. This causes
regional `torch.compile` from optimizing ComfyUI GGUF + LoRA to either
(1). exceed the recompialtion limit of 8, which results in suboptimal
performance, and (2). even if recompilation limit is increased, the
compilation time gets unnecessarily high (180s v.s. 20s for Flux).

This patch fixes the recompilation issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156891
Approved by: https://github.com/williamwen42, https://github.com/mlazos
2025-07-09 20:14:34 +00:00
Xuehai Pan
02715d0876 [BE][5/6] fix typos in test/ (test/dynamo/) (#157639)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157639
Approved by: https://github.com/yewentao256, https://github.com/jansel
ghstack dependencies: #157638
2025-07-06 06:34:25 +00:00
soulitzer
554b568040 Add internal use only utility to allow externally visible side effects within HOPs (#155715)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155715
Approved by: https://github.com/zou3519
2025-06-21 03:55:28 +00:00
Yidi Wu
fc859077a0 [export][cond] support merging constant ints as unbacked symint (#152742)
@pianpwk points out that this will be helpful to address several data dependent issues in huggingface [models](e23705e557/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py (L332)) with the following pattern:
```python
idx = return 0 if u0 else return 1
return  x[idx]
```
We could preserve the conditional with a cond.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152742
Approved by: https://github.com/zou3519
2025-05-22 17:25:38 +00:00
soulitzer
f2af30fee5 Add a HOP to bypass tracing of a wrapper function while tracing the wrapped function (#153487)
Usage:
```python
from torch._higher_order_ops.wrap import dynamo_bypassing_wrapper

# Your ordinary function wrapper
def my_hop_fn_impl(fn, *args, k=1, **kwargs):
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, tuple):
            return (out[0] + k,)
        return out + k

    return wrapper

# Calling `my_hop_fn` instead of the impl directly captures a HOP into the dynamo graph
def my_hop_fn(fn, *args, k=1, **kwargs):
    return dynamo_bypassing_wrapper(
        functools.partial(my_hop_fn_impl, k=k), fn, *args, **kwargs
    )
```

Notes:
- The dynamo captured graph now stashes arbitrary callable objects (the wrapper_fn) - this is equivalent to what SAC does today with policy_fn.
- The `wrapper_fn` passed to `dynamo_bypassing_wrapper ` should have signature `Callable -> Callable`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153487
Approved by: https://github.com/ydwu4
2025-05-22 04:24:38 +00:00
Thomas Bohnstingl
68034198e5 [HOP] Mutation and alias rework (#146658)
This PR reworks the way the input mutations and various aliases are checked

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146658
Approved by: https://github.com/ydwu4
2025-05-18 08:05:22 +00:00
PyTorch MergeBot
641e4bee67 Revert "[export][cond] support merging constant ints as unbacked symint (#152742)"
This reverts commit a805911d15.

Reverted https://github.com/pytorch/pytorch/pull/152742 on behalf of https://github.com/ydwu4 due to breaking trunk ([comment](https://github.com/pytorch/pytorch/pull/152742#issuecomment-2874410372))
2025-05-12 23:06:33 +00:00
Yidi Wu
a805911d15 [export][cond] support merging constant ints as unbacked symint (#152742)
@pianpwk points out that this will be helpful to address several data dependent issues in huggingface [models](e23705e557/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py (L332)) with the following pattern:
```python
idx = if u0 return 0 else return 1
return  x[idx]
```
We could preserve the conditional with a cond.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152742
Approved by: https://github.com/zou3519
2025-05-12 20:26:31 +00:00
Yidi Wu
ceb009baee [map] always turn on dynamo for map (#152041)
Summary:
X-link: https://github.com/pytorch/executorch/pull/10409

Reland D72896450

Make map consistent with other control flow ops. After the change, map is able to support accessing closures in the map fn.

Test Plan: See existing tests.

Reviewed By: zou3519

Differential Revision: D73138427

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152041
Approved by: https://github.com/zou3519
2025-05-12 02:10:08 +00:00
PyTorch MergeBot
4a47dd9b3f Revert "[map] always turn on dynamo for map (#150962)"
This reverts commit a72d56cb6b.

Reverted https://github.com/pytorch/pytorch/pull/150962 on behalf of https://github.com/Camyll due to breaking internal builds {SHORT_REASON} ([comment](https://github.com/pytorch/pytorch/pull/150962#issuecomment-2803006282))
2025-04-14 21:09:22 +00:00
Yidi Wu
a72d56cb6b [map] always turn on dynamo for map (#150962)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150962
Approved by: https://github.com/zou3519
2025-04-11 23:28:06 +00:00
bobrenjc93
f649ee73ce Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-28 05:36:32 +00:00
PyTorch MergeBot
af7719a2fa Revert "Use source hashing to generate consistent symbolic ids (#149665)"
This reverts commit 1f92348dc6.

Reverted https://github.com/pytorch/pytorch/pull/149665 on behalf of https://github.com/malfet due to Broke trunk, see 6eb3c2e282/1 ([comment](https://github.com/pytorch/pytorch/pull/149665#issuecomment-2758578187))
2025-03-27 16:02:27 +00:00
bobrenjc93
1f92348dc6 Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-27 03:39:27 +00:00
Yidi Wu
0a0a73a9a9 [cond] don't trace fw and bw graph in autograd key (#148930)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148930
Approved by: https://github.com/zou3519
2025-03-24 17:07:29 +00:00
Guilherme Leobas
406d464d97 Add is_batchedtensor to dynamo builder (#149541)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149541
Approved by: https://github.com/zou3519
2025-03-20 20:46:15 +00:00
Yidi Wu
824474cb35 [cond] support output sizes mismatch in front end (#147130)
This PR finishes https://github.com/pytorch/pytorch/pull/137615 by addressing the TODOs and comments left there.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147130
Approved by: https://github.com/zou3519
2025-02-25 20:28:41 +00:00
eellison
92b7e610ab [Inductor changes] Invoke Quant (#139102)
Adds a `invoke_quant` higher order operator as proposed [here](https://docs.google.com/document/d/1s2PfJlq6Q1F8l11CkTIC69BW1rEnGEgs6YmBC7hu8rA/edit?tab=t.0).

The primary motivations are

- Unifying scattered reasoning for quant operators throughout the code base

- Easy of pattern matching - see this very large pattern match expression [here](949fdd2997/torch/_inductor/fx_passes/post_grad.py (L390-L426). Compared to the pattern I have in the tests:

```
        @register_graph_pattern(
            CallFunction(
                torch.ops.aten.mm,
                CallFunction(
                    torch.ops.higher_order.invoke_quant,
                    Ignored(),
                    Ignored(),
                    Ignored(),
                    scheme="nf4",
                ),
                Arg(),
            ),
            pass_dict=test_pass,
        )
```

- Ability to specify inductor specific logic, like codegen'ing the operators in lower precision, or forcing fusion to a matmul.

Example graph:

``` Python
 ===== AFTER POST GRAD =====
 /data/users/eellison/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
         # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_quant: "f32[8][1]cpu" = torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4');  repeated_subgraph0 = arg0_1 = arg1_1 = None
        return (invoke_quant,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[8][1]cpu", arg1_1: "f32[8][1]cpu"):
             # File: /data/users/eellison/pytorch/torch/_higher_order_ops/invoke_quant.py:87 in __call__, code: return invoke_quant_tracer(*args, **kwargs, quant_options=self)  # type: ignore[call-arg]
            mul: "f32[8][1]cpu" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = None
            add: "f32[8][1]cpu" = torch.ops.aten.add.Tensor(mul, arg1_1);  mul = arg1_1 = None
            return add
```

The schema for `invoke_quant` is `torch.ops.higher_order.invoke_quant(subgraph, *args, scheme=None)` where the scheme will not always be present.

I wasn't sure exactly how the inductor specific configurations like `codgen_in_low_precision` should be passed through. I didnt want to stuff them all in as kwargs, and I didn't want to have them affect pattern matching. So they will be stored as meta of the node itself. And, following that, I wanted the invocation of the hop to match how it will show up in the graph. So I decided to have it be an object that is then invoked for the tracing.

```
invoke_quant = InvokeQuant(codegen_low_precision=True)
invoke_quant(gn, (x, y), scheme="nf4")
```
Todo - not require the packing of args in a tuple, will do following https://github.com/pytorch/pytorch/pull/139162.

Feedback welcome.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139102
Approved by: https://github.com/Chillee
2025-02-08 19:30:19 +00:00
Yanbo Liang
bd8d7b1b74 [Dynamo][Trace PyDispatcher] Remove disable from HigherOrderOperator.__call__ (#146270)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146270
Approved by: https://github.com/zou3519
2025-02-03 21:47:54 +00:00
Simon Fan
2e197c8a2d [dynamo][hop] test torch.compiling all HOPs (#145422)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145422
Approved by: https://github.com/ydwu4, https://github.com/zou3519
2025-01-31 20:45:22 +00:00
Ryan Guo
eaec97ab1f [dynamo] Properly prune dead input cell object (#145781)
This patch models input cell object as "newly created" rather than
"pre-existing" python object (see added documentation for why this
actually captures the semantics more accurately).

This enables the `SideEffects.prune_dead_object_new` algorithm to prune
away writes to input cell objects which are no longer relevant; this
didn't happen prior to this patch because we modelled them as
pre-existing objects, which forces us to codegen their attribute
mutations.

Fixes #145564.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145781
Approved by: https://github.com/williamwen42, https://github.com/jansel
2025-01-28 18:28:13 +00:00
Animesh Jain
19584b28fd [dynamo][dicts] Consolidate dict(..) construction (#144342)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144342
Approved by: https://github.com/StrongerXi
2025-01-20 04:42:06 +00:00
PyTorch MergeBot
5e6e6200bf Revert "[dynamo][dicts] Consolidate dict(..) construction (#144342)"
This reverts commit a54a784b82.

Reverted https://github.com/pytorch/pytorch/pull/144342 on behalf of https://github.com/kit1980 due to breaking internal builds, see D68125388 ([comment](https://github.com/pytorch/pytorch/pull/144342#issuecomment-2597184167))
2025-01-17 00:32:09 +00:00
Animesh Jain
a54a784b82 [dynamo][dicts] Consolidate dict(..) construction (#144342)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144342
Approved by: https://github.com/StrongerXi
2025-01-13 22:24:56 +00:00
Yidi Wu
c36f94b373 [while_loop][dynamo] auto-unspecialize int input and output to unbacked symints (#143106)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143106
Approved by: https://github.com/zou3519
ghstack dependencies: #143105, #143545
2025-01-03 19:01:07 +00:00
Tom Ritchford
d25e6e623f Fix unused Python variables in test/[a-d]* (#134665)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134665
Approved by: https://github.com/albanD
2024-12-13 22:13:12 +00:00
Yidi Wu
7111cd6ee0 [hop][BE] add util diff_meta with prettier error message. (#142162)
The error message changes from:
```python
-torch._dynamo.exc.Unsupported: Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:[('pair0:', TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}), TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}))]
```
to
```python
+torch._dynamo.exc.Unsupported: Expect branches to return tensors with same metadata but find pair[0] differ in 'shape', where lhs is TensorMetadata(shape=torch.Size([4, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={}) and rhs is TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=None, is_quantized=False, qparams={})
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142162
Approved by: https://github.com/zou3519
2024-12-10 21:54:28 +00:00
Yidi Wu
9ced54a51a [hop] lift free symbols in slice (#142385)
Before the change, we get an unfound proxy error when linting the subgraph.

After the change, we have the following dynamo graph for dynamic_shape test.

```python
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]  /data/users/yidi/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]     def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", s2: "Sym(s2)", L_x_: "f32[s0, s1, s2][s1*s2, s2, 1]cpu"):
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]         l_x_ = L_x_
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]          # File: /data/users/yidi/pytorch/test/dynamo/test_higher_order_ops.py:307 in f, code: i = x.size(0) - 2
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]         sub: "Sym(s0 - 2)" = s0 - 2
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]          # File: /data/users/yidi/pytorch/test/dynamo/test_higher_order_ops.py:308 in f, code: j = x.size(1) - 3
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]         sub_1: "Sym(s1 - 3)" = s1 - 3
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]          # File: /data/users/yidi/pytorch/test/dynamo/test_higher_order_ops.py:310 in f, code: return wrap(lambda x: x[:i, :j, k:], x)
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]         wrap_body_0 = self.wrap_body_0
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]         wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, s2, l_x_, sub, sub_1);  wrap_body_0 = s0 = s1 = s2 = l_x_ = sub = sub_1 = None
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]         getitem: "f32[s0 - 2, s1 - 3, 0][s1*s2, s2, 1]cpu" = wrap[0];  wrap = None
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]         return (getitem,)
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]     class wrap_body_0(torch.nn.Module):
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]         def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", s2: "Sym(s2)", l_x_: "f32[s0, s1, s2][s1*s2, s2, 1]cpu", sub: "Sym(s0 - 2)", sub_1: "Sym(s1 - 3)"):
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]              # File: /data/users/yidi/pytorch/test/dynamo/test_higher_order_ops.py:310 in <lambda>, code: return wrap(lambda x: x[:i, :j, k:], x)
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]             getitem: "f32[s0 - 2, s1 - 3, 0][s1*s2, s2, 1]cpu" = l_x_[(slice(None, sub, None), slice(None, sub_1, None), slice(s2, None, None))];  l_x_ = sub = sub_1 = s2 = None
V1209 11:11:06.187000 4091124 torch/_dynamo/output_graph.py:1346] [0/2] [__graph_code]             return (getitem,)
```

We lift sub, sub_1 because they're compound expressions and are directly used in argument of the getitem node. We lift s0, s1 and s2 because they're basic symbols in the tensor input.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142385
Approved by: https://github.com/zou3519
2024-12-10 21:52:30 +00:00
Ryan Guo
9d54cd1504 [dynamo] Undo some jvp old workarounds in functorch (#142081)
This basically undoes some workarounds introduced in #119926, the
root causes of which have been fixed by #142078 and other changes in
Dynamo.

Now that Dynamo traces the spec comparison code, the test also needs update:
- removing the `_jvp_treespec_compare` calls in fx graph

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142081
Approved by: https://github.com/zou3519
ghstack dependencies: #142078, #142080
2024-12-06 08:06:53 +00:00
Ryan Guo
59de5e867b [dynamo] Undo some vjp old workarounds in functorch (#142080)
This basically undoes most of the workarounds introduced in #119405, the
root causes of which have been fixed by #142078 and other changes in
Dynamo.

Now that Dynamo traces the spec comparison code, the test also needs update:
1. renaming `o` to `pimals_out`
2. removing the `_vjp_treespec_compare` calls in fx graph

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142080
Approved by: https://github.com/zou3519
ghstack dependencies: #142078
2024-12-06 08:06:53 +00:00
PyTorch MergeBot
ad37afd590 Revert "Always unspecialize float in OSS (#138922)"
This reverts commit ba5253da9b.

Reverted https://github.com/pytorch/pytorch/pull/138922 on behalf of https://github.com/yf225 due to perf regression on torchbench ([comment](https://github.com/pytorch/pytorch/pull/138922#issuecomment-2499277511))
2024-11-26 00:03:03 +00:00
Bob Ren
ba5253da9b Always unspecialize float in OSS (#138922)
Fixes https://github.com/pytorch/pytorch/issues/107277

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138922
Approved by: https://github.com/ezyang

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
2024-11-24 01:58:13 +00:00
PyTorch MergeBot
a8c90e5140 Revert "Always unspecialize float in OSS (#138922)"
This reverts commit 6d779d0549.

Reverted https://github.com/pytorch/pytorch/pull/138922 on behalf of https://github.com/huydhn due to Sorry for reverting your change but there is some slow tests failing after this land ([comment](https://github.com/pytorch/pytorch/pull/138922#issuecomment-2495076878))
2024-11-22 23:18:36 +00:00
Bob Ren
6d779d0549 Always unspecialize float in OSS (#138922)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138922
Approved by: https://github.com/ezyang

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
2024-11-22 17:54:42 +00:00
Guilherme Leobas
7ced49d2cc Raise exception if vmap (eager) calls compiled function (#140439)
Fixes #138422

This is not a proper fix for #140439, but more of a way to prevent a user from seeing a nasty error inside the C++ code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140439
Approved by: https://github.com/zou3519
2024-11-19 16:27:48 +00:00
Ryan Guo
ea1d11cf74 [dynamo] Represent all cells as NewCellVariable (#140153)
In addition to `NewCellVariable`, Dynamo has 3 ways of modeling cell objects:
1. For cells captured and created by the root frame, represent them as
   their contents in `root_tx.symbolic_locals`, which `LOAD_DEREF` and
   `STORE_DEREF` update directly, without going through `SideEffects`.
2. `ClosureVariable`: this is created when cells from (1) are captured
   by a newly created function Dynamo is about to inline. It's a handle
   with a name that redirects `LOAD_DEREF` and `STORE_DEREF` back (1),
   to make `root_tx.symbolic_locals` up-to-date.
3. For cells that are captured by both the root frame and some
   pre-existing function Dynamo is about to inline, represent those
   cells as contents, and do not allow writes to them.

Note that (2) and (3) are mainly to conform with (1) -- to make sure
Dynamo has a consistent modeling of cells for the same cell objects.

In this patch, we represent all of these cells as `NewCellVariable`. The
main new code paths introduced are:
- using `NewCellVariable` to model cell objects created by the root
  frame (the cells are passed in as input to `InstructionTranslator`),
  this is what allows us to get rid of all 3 legacy paths above.
- adding a new `AutoDerefLocalSource` to deal with the python-code
  level (guards) and bytecode level (codegen) auto-dereferencing
  behavior, when accessing pre-existing python cells. This also
  involves a tiny update to guard manager generation.
- plumbing some extra info into `LocalSource` and `CellVariable` so that
  we can still emit `LOAD_DEREF`, `STORE_DEREF`, `LOAD_CLOSURE` (instead
  of `make_cell`, `cell_contents` attribute access, and `LOAD_FAST`),
  which is important for readability, performance, and some
  assumptions `bytecode_transformation.py` makes.

As a result, this patch removes a lot of the now-dead code paths and
TODOs. Notably, it significantly simplified the `prune_dead_locals`
function, which was duplicating a lot of the logic from
`prune_dead_object_new`; this conveniently closes #137123.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140153
Approved by: https://github.com/jansel
ghstack dependencies: #140330, #140152, #140436, #140435
2024-11-15 17:17:30 +00:00