Commit Graph

175 Commits

Author SHA1 Message Date
Sherlock Huang
34d6ef7022 Update gm.print_readable to include Annotation (#165397)
Sample output
```
[rank0]:        # Annotation: {'compile_with_inductor': 'flex_attention'} File: /data/users/bahuang/pytorch/torch/nn/attention/flex_attention.py:1490 in flex_attention, code: out, lse, max_scores = flex_attention_hop(
[rank0]:        score_mod_2 = self.score_mod_2
[rank0]:        mask_fn_2 = self.mask_fn_2
[rank0]:        flex_attention_1 = torch.ops.higher_order.flex_attention(xq_5, xk_5, xv_3, score_mod_2, (2048, 2048, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_indices, 128, 128, mask_fn_2), 0.25, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___mask_mod___closure___0_cell_contents,));  xq_5 = xk_5 = xv_3 = score_mod_2 = mask_fn_2 = None
[rank0]:        out_2: "bf16[8, 4, 2048, 16]" = flex_attention_1[0];  flex_attention_1 = None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165397
Approved by: https://github.com/yushangdi, https://github.com/anijain2305, https://github.com/mlazos
2025-10-28 13:54:38 +00:00
PyTorch MergeBot
e50dc40d28 Revert "Update gm.print_readable to include Annotation (#165397)"
This reverts commit 7a65770013.

Reverted https://github.com/pytorch/pytorch/pull/165397 on behalf of https://github.com/malfet due to I don't know how/why, but it breaks windows tests, see 2e22b1a61e/1 ([comment](https://github.com/pytorch/pytorch/pull/165397#issuecomment-3417428128))
2025-10-17 22:35:50 +00:00
Sherlock Huang
7a65770013 Update gm.print_readable to include Annotation (#165397)
Sample output
```
[rank0]:        # Annotation: {'compile_with_inductor': 'flex_attention'} File: /data/users/bahuang/pytorch/torch/nn/attention/flex_attention.py:1490 in flex_attention, code: out, lse, max_scores = flex_attention_hop(
[rank0]:        score_mod_2 = self.score_mod_2
[rank0]:        mask_fn_2 = self.mask_fn_2
[rank0]:        flex_attention_1 = torch.ops.higher_order.flex_attention(xq_5, xk_5, xv_3, score_mod_2, (2048, 2048, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_kv_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___q_indices, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_num_blocks, g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___full_q_indices, 128, 128, mask_fn_2), 0.25, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), (g____import_torchtitan_dot_models_dot_attention___flex_attention_block_masks___block_causal___none___mask_mod___closure___0_cell_contents,));  xq_5 = xk_5 = xv_3 = score_mod_2 = mask_fn_2 = None
[rank0]:        out_2: "bf16[8, 4, 2048, 16]" = flex_attention_1[0];  flex_attention_1 = None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165397
Approved by: https://github.com/yushangdi, https://github.com/anijain2305
2025-10-17 18:35:18 +00:00
Yuanyuan Chen
e925dfcc6b Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang, https://github.com/mlazos
2025-10-17 07:27:11 +00:00
Yuanyuan Chen
8de85896e0 Enable ruff rule E721 (#165162)
`E721` checks for object type comparisons using == and other comparison operators. This is useful because it is recommended to use `is` for type comparisons.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162
Approved by: https://github.com/Skylion007
2025-10-13 01:48:55 +00:00
PyTorch MergeBot
816fb7f48d Revert "Enable ruff rule E721 (#165162)"
This reverts commit 9e7c19f72b.

Reverted https://github.com/pytorch/pytorch/pull/165162 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/165162#issuecomment-3393328271))
2025-10-11 13:25:40 +00:00
Yuanyuan Chen
9e7c19f72b Enable ruff rule E721 (#165162)
`E721` checks for object type comparisons using == and other comparison operators. This is useful because it is recommended to use `is` for type comparisons.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162
Approved by: https://github.com/Skylion007
2025-10-11 06:43:53 +00:00
PyTorch MergeBot
5d7360bb03 Revert "Enable all SIM rules except disabled ones (#164645)"
This reverts commit 321e602692.

Reverted https://github.com/pytorch/pytorch/pull/164645 on behalf of https://github.com/izaitsevfb due to causes lint failures ([comment](https://github.com/pytorch/pytorch/pull/164645#issuecomment-3369274351))
2025-10-05 19:32:21 +00:00
Yuanyuan Chen
321e602692 Enable all SIM rules except disabled ones (#164645)
`SIM` rules are useful for simplifying boolean expressions and enhances code readability.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164645
Approved by: https://github.com/ezyang
2025-10-05 07:38:25 +00:00
bobrenjc93
8c54101933 add tensor subclass printing support in fx/graph.py (#164403)
it was previously quite misleading since it looks like the inputs to the
dynamo graph are plain tensors when in reality they are tensor subclasses

before
```
class GraphModule(torch.nn.Module):
    def forward(self, L_input_batch_inputs_: "i64[2, 512][512, 1]cuda:0", L_self_parameters_weight_: "f32[202048, 256][256, 1]cuda:0"):
```

after
```
    class GraphModule(torch.nn.Module):
        def forward(self, L_input_batch_inputs_: "DTensor(i64[2, 512][512, 1]cuda:0)", L_self_parameters_weight_: "DTensor(f32[202048, 256][256, 1]cuda:0)"):
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164403
Approved by: https://github.com/ezyang
2025-10-02 20:06:12 +00:00
Yuanyuan Chen
a8c528c105 [1/N] Apply UP035 rule in tests (#163947)
Apply UP035 `ruff` rule in tests, but some tests for `fx` and `dynamo` are excluded in case the old typing is the test target.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163947
Approved by: https://github.com/ezyang
2025-09-29 01:42:01 +00:00
ghostspiders
af10f1f86c Fix requires_cuda to requires_cuda_and_triton (#160222)
Fixes ##159399

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160222
Approved by: https://github.com/janeyx99
2025-08-10 07:05:52 +00:00
gaoyvfeng
50f23ff6f8 rename-HAS_CUDA-to-HAS_CUDA_AND_TRITON (#159883)
Fixes #159399
"Modified torch.testing._internal.inductor_utils and test/inductor"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159883
Approved by: https://github.com/janeyx99
2025-08-08 15:44:52 +00:00
Edward Z. Yang
204eb4da5e Add expanded_def option for FX printing, render descriptor, update tests (#158708)
----

- First, we add a new expanded_def to FX, which will expand the
  definitions of variables into multiple lines, one per variable
  definition.  This makes extremely long args/return lists much
  more readable.

- Next, we extend this mechanism to also print out descriptors on
  placeholders and return values, as comments, if available.  This
  is how we will test descriptors.

- We update tlparse for AOTAutograd to use this format.

- We update expect tests to use this format and update their formats,
  so you can inspect what it can look at.  There may be other tests
  I should update, open to suggestions.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158708
Approved by: https://github.com/wconstab
ghstack dependencies: #158624
2025-07-25 13:22:32 +00:00
Xuehai Pan
02715d0876 [BE][5/6] fix typos in test/ (test/dynamo/) (#157639)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157639
Approved by: https://github.com/yewentao256, https://github.com/jansel
ghstack dependencies: #157638
2025-07-06 06:34:25 +00:00
Xuehai Pan
6d5c789ad5 [BE][PYFMT] migrate PYFMT for test/[a-h]*/ to ruff format (#144555)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144555
Approved by: https://github.com/ezyang
ghstack dependencies: #144551, #144554
2025-06-24 04:53:54 +00:00
Sean McGovern
297805fd8f Typo fixes for "overridden" in comments and function names (#155944)
This word appears often in class descriptions and is not consistently spelled. Update comments and some function names to use the correct spelling consistently. Facilitates searching the codebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155944
Approved by: https://github.com/Skylion007
2025-06-14 03:37:38 +00:00
PyTorch MergeBot
06408dae49 Revert "Add view_simple as meta function for view, and avoid calling reshape_view_helper. (#154757)"
This reverts commit 0029259bdf.

Reverted https://github.com/pytorch/pytorch/pull/154757 on behalf of https://github.com/laithsakka due to post land issue ([comment](https://github.com/pytorch/pytorch/pull/154757#issuecomment-2971385787))
2025-06-13 19:11:43 +00:00
angelayi
0860606729 [export] Add meta[val] to getattr nodes (#154934)
Fixes [P1830293318](https://www.internalfb.com/intern/paste/P1830293318/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154934
Approved by: https://github.com/yushangdi, https://github.com/muchulee8
2025-06-13 05:48:21 +00:00
Laith Sakka
0029259bdf Add view_simple as meta function for view, and avoid calling reshape_view_helper. (#154757)
address https://github.com/pytorch/pytorch/issues/153303

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154757
Approved by: https://github.com/bobrenjc93, https://github.com/leslie-fang-intel
2025-06-12 09:58:15 +00:00
Aaron Orenstein
fc0135ca11 Re-enable FakeTensor caching for SymInts (#152662)
Summary:

This backs out D60320595 which itself turned off FakeTensor caching when a SymInt was present.

There has been a lot of dynamic shape fixes done this year and tests pass so I'm assuming some of that work fixed what was breaking previously.

Test Plan: Reran the tests listed in T196779132 and they pass.

## Perf
### Instruction Counter Benchmark:
- 26% win on add_loop_eager_dynamic
- 13% win on add_loop_inductor_dynamic_gpu
### Perf Dashboard
Compilation Latency wins across the board but especially strong on the dynamic tests (like cudagraphs_dynamic) - for example MobileBertForMaskedLM went from 66s -> 50s.

Differential Revision: [D75467694](https://our.internmc.facebook.com/intern/diff/D75467694)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152662
Approved by: https://github.com/anijain2305
2025-05-30 17:23:36 +00:00
Ryan Guo
75bbd4989c [dynamo] Support using symint from dispatcher-style tensor subclass (#154130)
Fixes #146932.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154130
Approved by: https://github.com/laithsakka
2025-05-27 19:05:46 +00:00
PyTorch MergeBot
3f64502c98 Revert "Re-enable FakeTensor caching for SymInts (#152662)"
This reverts commit 7d11c61c26.

Reverted https://github.com/pytorch/pytorch/pull/152662 on behalf of https://github.com/malfet due to Looks like it broke bunch of inductor tests, see 187d38185e/1 ([comment](https://github.com/pytorch/pytorch/pull/152662#issuecomment-2910293593))
2025-05-26 17:13:22 +00:00
Aaron Orenstein
7d11c61c26 Re-enable FakeTensor caching for SymInts (#152662)
Summary:

This backs out D60320595 which itself turned off FakeTensor caching when a SymInt was present.

There has been a lot of dynamic shape fixes done this year and tests pass so I'm assuming some of that work fixed what was breaking previously.

Test Plan: Reran the tests listed in T196779132 and they pass.

## Perf
### Instruction Counter Benchmark:
- 26% win on add_loop_eager_dynamic
- 13% win on add_loop_inductor_dynamic_gpu
### Perf Dashboard
Compilation Latency wins across the board but especially strong on the dynamic tests (like cudagraphs_dynamic) - for example MobileBertForMaskedLM went from 66s -> 50s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152662
Approved by: https://github.com/anijain2305
2025-05-26 04:17:56 +00:00
Ryan Guo
18e13a67ce [dynamo] Harden torch function dispatchability check for attributes and methods access (#153082)
See more details in
https://github.com/pytorch/pytorch/issues/151771#issuecomment-2836372110.

Fixes #151771.

Differential Revision: [D74342291](https://our.internmc.facebook.com/intern/diff/D74342291)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153082
Approved by: https://github.com/mlazos
2025-05-09 16:14:23 +00:00
Ryan Guo
51e77f3b30 [dynamo] replace unimplemented with unimplemented_v2 in variables/torch_functions.py (#151278)
This addresses part of #147913.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151278
Approved by: https://github.com/Skylion007, https://github.com/williamwen42
ghstack dependencies: #151277
2025-05-05 18:45:40 +00:00
Simon Fan
c461ba6522 [aot] mark dynamic activations as maybe dynamic (#149707)
Today, we mark graph outputs as maybe dynamic, this lets a compilation to communicate to future compilations whether certain graph inputs are dynamic. Similarly, we can do this to saved activations, which may be used in future compilations as well. This is especially prevalent in compiled autograd, where tensor activations will always become graph inputs.

Changes to the tests were mainly cosmetic, with the exception of tests that relied on duck shaping. By annotating tensor dims, we prevent them from reusing pre-existing symbols, so this change will make graphs use duck shapes less than before, which affects some of the caching tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149707
Approved by: https://github.com/bdhirsh
2025-05-01 21:59:36 +00:00
Brian Hirsh
4a63cab624 [cudagraphs] Fix issue in collecting static_input_idxs (#152287)
related to https://github.com/pytorch/pytorch/issues/152275

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152287
Approved by: https://github.com/bdhirsh, https://github.com/eellison

Co-authored-by: Brian Hirsh <hirsheybar@fb.com>
2025-04-30 03:24:05 +00:00
PyTorch MergeBot
a6d19fcfac Revert "[cudagraphs] Fix issue in collecting static_input_idxs (#152287)"
This reverts commit 75a564608a.

Reverted https://github.com/pytorch/pytorch/pull/152287 on behalf of https://github.com/wdvr due to causing ao failures - discussed with author ([comment](https://github.com/pytorch/pytorch/pull/152287#issuecomment-2837686127))
2025-04-29 06:57:06 +00:00
Animesh Jain
75a564608a [cudagraphs] Fix issue in collecting static_input_idxs (#152287)
related to https://github.com/pytorch/pytorch/issues/152275

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152287
Approved by: https://github.com/bdhirsh, https://github.com/eellison
2025-04-28 23:07:52 +00:00
Ryan Guo
f66229de2b [dynamo] Remove traceable_tensor_subclasses-related code (#151062)
Since #149792 deprecates `traceable_tensor_subclasses` and it's been
landed for over a week, we can safely remove all the old code that uses
`traceable_tensor_subclasses` (they were primarily for testing purposes
and are equivalent to no-ops now).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151062
Approved by: https://github.com/mlazos, https://github.com/anijain2305
ghstack dependencies: #151060, #151061
2025-04-15 03:55:35 +00:00
Ryan Guo
6a1499d209 [dynamo] handle tensor subclass with non-classmethod __torch_function__ (#151061)
As title, this patch fixes bugs in
1. emulating `has_torch_function`
2. emulating calling `__torch_function__`
3. building a callable VT for non-classmethod `__torch_function__`

Fixes #120799, #150265, #150848.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151061
Approved by: https://github.com/anijain2305, https://github.com/mlazos
ghstack dependencies: #151060
2025-04-15 03:55:34 +00:00
Ryan Guo
bb98749230 [dynamo] Always trace into tensor subclass __torch_function__ (#149792)
This patch effectively ignores traceable_tensor_subclasses, allowing
Dynamo to always try tracing into the `__torch_function__` of tensor
subclass. This helps us with 2 things:
1. allowing users to directly benefit from better compilation of tensor
   subclass, by just upgrading pytorch, without having to change legacy
   library code (see earlier patches in the stack for examples).
2. potentially exposing more issues in compiling tensor subclass, so we
   can get signals and improve them.

As a consequence, it exposed and fixes 2 subtle bugs:
1. In `build_torch_function_fn`, we could get
   `torch._C._disabled_torch_function_impl` because we have a
   `Parameter` subclass without `__torch_function__` override or if we
   have a tensor subclass with `__torch_dispatch__` override. We graph
   break on this for now, and plan to add support -- the logic for
   simulating `torch._C._disabled_torch_function_impl` is already in
   `SuperVariable`, we just need to reuse it.
2. Sometimes we create `SyntheticLocalSource` and need to remove all the
   guards installed on it, but we only removed the ones whose source
   _is_ the created synthetic source `s`, but forgot about chained
   source like `s.foo`, this showed up as
   `SYNTHETIC_LOCAL['tmp_0'].__torch_function__.__func__`.

Differential Revision: [D71906141](https://our.internmc.facebook.com/intern/diff/D71906141)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149792
Approved by: https://github.com/jansel, https://github.com/mlazos
ghstack dependencies: #149482, #149483, #149484
2025-04-02 20:57:00 +00:00
Ryan Guo
3463ea1059 [dynamo] Support tensor subclass with overriden tensor methods and properties (#149484)
This fixes most of the "torch.compile X tensor-subclass" issues
encountered in https://github.com/city96/ComfyUI-GGUF/issues/118. The
relevant tensor subclass definition is here:
298192ed60/ops.py (L18-L65).

A few things to note about the tensor subclass:
1. it overrides a lot of the `torch.Tensor` methods (e.g., `to`,
   `clone`), so this patch updates `TensorWithTFOverrideVariable.var_getattr`
   to support that.
2. it overrides the `shape` property, so this patch updates
   `TensorWithTFOverrideVariable.var_getattr` to support property as well.
3. it has calls to `torch.Tensor.size`, which returns `torch.Size`,
   which gets reconstructed in `torch.Tensor.__torch_function__`, so
   this patch adds support for calling `torch.Size(...)` on non-constant
   inputs.

Differential Revision: [D71906137](https://our.internmc.facebook.com/intern/diff/D71906137)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149484
Approved by: https://github.com/jansel, https://github.com/mlazos
ghstack dependencies: #149482, #149483
2025-04-02 20:57:00 +00:00
Ryan Guo
0d4dbfd9ed [dynamo] Support torch.Tensor._make_subclass and tracing through tensor subclass __new__ (#149483)
This builds off the previous patch in the stack, and fully fixes
https://github.com/huggingface/diffusers/issues/10795.

Essentially, tensor subclass in the issue uses
`torch.Tensor._make_subclass`, which has a pretty simple shallow-copy
plus type change semantics, as far as Dynamo is concerned. So this patch
adds a polyfill for it.

As a result, this allows us to trace through many user-defined `__new__`
in tensor subclass (it's similar to how we trace through user-defined
`__new__` for `UserDefinedClassVariable`), so this patch also faithfully
trace through these `__new__` methods.

Differential Revision: [D71906139](https://our.internmc.facebook.com/intern/diff/D71906139)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149483
Approved by: https://github.com/zou3519, https://github.com/mlazos
ghstack dependencies: #149482
2025-04-02 20:56:52 +00:00
Ryan Guo
33535b3eee [dynamo] Support Tensor subclass that has dynamic attributes or calls Parameter.__torch_function__ (#149482)
This fixes most of https://github.com/huggingface/diffusers/issues/10795,
except for `torch.Tensor._make_subclass`, which will be fixed in a
subsequent patch.

The relevant tensor subclass from the aforementioned issue is defined
here: fbf6b856cc/src/diffusers/quantizers/gguf/utils.py (L398-L435).

There are two things to note about the tensor subclass:
1. it calls `super().__torch_function__`, which is
   `torch._C._disabled_torch_function_impl`, so this patch updates
   `SuperVariable.call_method` to handle it (we can't do a simpler
   polyfill due to some bug with `var_getattr` raising
   `NotImplementedError`, which forgot to restore symbolic context).
2. it sets and reads attributes (`quant_type`), and
   defines new methods (`as_data`), so this patch adds support for those.
3. it has a `__init__`, which Dynamo needs to trace through in
   `TensorSubclassVariable.call_function`.

Differential Revision: [D71906140](https://our.internmc.facebook.com/intern/diff/D71906140)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149482
Approved by: https://github.com/jansel, https://github.com/mlazos
2025-04-02 20:56:43 +00:00
PyTorch MergeBot
03c879d59b Revert "[dynamo] Support Tensor subclass that has dynamic attributes or calls Parameter.__torch_function__ (#149482)"
This reverts commit 98453c135a.

Reverted https://github.com/pytorch/pytorch/pull/149482 on behalf of https://github.com/malfet due to Broke trunk, see b03c42109c/1 ([comment](https://github.com/pytorch/pytorch/pull/149482#issuecomment-2773650522))
2025-04-02 20:30:33 +00:00
PyTorch MergeBot
18908c8ced Revert "[dynamo] Support torch.Tensor._make_subclass and tracing through tensor subclass __new__ (#149483)"
This reverts commit 203e1d681d.

Reverted https://github.com/pytorch/pytorch/pull/149483 on behalf of https://github.com/malfet due to Broke trunk, see b03c42109c/1 ([comment](https://github.com/pytorch/pytorch/pull/149482#issuecomment-2773650522))
2025-04-02 20:30:33 +00:00
PyTorch MergeBot
01411c739f Revert "[dynamo] Support tensor subclass with overriden tensor methods and properties (#149484)"
This reverts commit 7e53c58687.

Reverted https://github.com/pytorch/pytorch/pull/149484 on behalf of https://github.com/malfet due to Broke trunk, see b03c42109c/1 ([comment](https://github.com/pytorch/pytorch/pull/149482#issuecomment-2773650522))
2025-04-02 20:30:33 +00:00
PyTorch MergeBot
e545567340 Revert "[dynamo] Always trace into tensor subclass __torch_function__ (#149792)"
This reverts commit 238109ad32.

Reverted https://github.com/pytorch/pytorch/pull/149792 on behalf of https://github.com/malfet due to Broke trunk, see b03c42109c/1 ([comment](https://github.com/pytorch/pytorch/pull/149482#issuecomment-2773650522))
2025-04-02 20:30:32 +00:00
Ryan Guo
238109ad32 [dynamo] Always trace into tensor subclass __torch_function__ (#149792)
This patch effectively ignores traceable_tensor_subclasses, allowing
Dynamo to always try tracing into the `__torch_function__` of tensor
subclass. This helps us with 2 things:
1. allowing users to directly benefit from better compilation of tensor
   subclass, by just upgrading pytorch, without having to change legacy
   library code (see earlier patches in the stack for examples).
2. potentially exposing more issues in compiling tensor subclass, so we
   can get signals and improve them.

As a consequence, it exposed and fixes 2 subtle bugs:
1. In `build_torch_function_fn`, we could get
   `torch._C._disabled_torch_function_impl` because we have a
   `Parameter` subclass without `__torch_function__` override or if we
   have a tensor subclass with `__torch_dispatch__` override. We graph
   break on this for now, and plan to add support -- the logic for
   simulating `torch._C._disabled_torch_function_impl` is already in
   `SuperVariable`, we just need to reuse it.
2. Sometimes we create `SyntheticLocalSource` and need to remove all the
   guards installed on it, but we only removed the ones whose source
   _is_ the created synthetic source `s`, but forgot about chained
   source like `s.foo`, this showed up as
   `SYNTHETIC_LOCAL['tmp_0'].__torch_function__.__func__`.

Differential Revision: [D71906141](https://our.internmc.facebook.com/intern/diff/D71906141)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149792
Approved by: https://github.com/jansel, https://github.com/mlazos
ghstack dependencies: #149482, #149483, #149484
2025-04-02 17:05:25 +00:00
Ryan Guo
7e53c58687 [dynamo] Support tensor subclass with overriden tensor methods and properties (#149484)
This fixes most of the "torch.compile X tensor-subclass" issues
encountered in https://github.com/city96/ComfyUI-GGUF/issues/118. The
relevant tensor subclass definition is here:
298192ed60/ops.py (L18-L65).

A few things to note about the tensor subclass:
1. it overrides a lot of the `torch.Tensor` methods (e.g., `to`,
   `clone`), so this patch updates `TensorWithTFOverrideVariable.var_getattr`
   to support that.
2. it overrides the `shape` property, so this patch updates
   `TensorWithTFOverrideVariable.var_getattr` to support property as well.
3. it has calls to `torch.Tensor.size`, which returns `torch.Size`,
   which gets reconstructed in `torch.Tensor.__torch_function__`, so
   this patch adds support for calling `torch.Size(...)` on non-constant
   inputs.

Differential Revision: [D71906137](https://our.internmc.facebook.com/intern/diff/D71906137)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149484
Approved by: https://github.com/jansel, https://github.com/mlazos
ghstack dependencies: #149482, #149483
2025-04-02 17:05:25 +00:00
Ryan Guo
203e1d681d [dynamo] Support torch.Tensor._make_subclass and tracing through tensor subclass __new__ (#149483)
This builds off the previous patch in the stack, and fully fixes
https://github.com/huggingface/diffusers/issues/10795.

Essentially, tensor subclass in the issue uses
`torch.Tensor._make_subclass`, which has a pretty simple shallow-copy
plus type change semantics, as far as Dynamo is concerned. So this patch
adds a polyfill for it.

As a result, this allows us to trace through many user-defined `__new__`
in tensor subclass (it's similar to how we trace through user-defined
`__new__` for `UserDefinedClassVariable`), so this patch also faithfully
trace through these `__new__` methods.

Differential Revision: [D71906139](https://our.internmc.facebook.com/intern/diff/D71906139)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149483
Approved by: https://github.com/zou3519, https://github.com/mlazos
ghstack dependencies: #149482
2025-04-02 17:05:19 +00:00
Ryan Guo
98453c135a [dynamo] Support Tensor subclass that has dynamic attributes or calls Parameter.__torch_function__ (#149482)
This fixes most of https://github.com/huggingface/diffusers/issues/10795,
except for `torch.Tensor._make_subclass`, which will be fixed in a
subsequent patch.

The relevant tensor subclass from the aforementioned issue is defined
here: fbf6b856cc/src/diffusers/quantizers/gguf/utils.py (L398-L435).

There are two things to note about the tensor subclass:
1. it calls `super().__torch_function__`, which is
   `torch._C._disabled_torch_function_impl`, so this patch updates
   `SuperVariable.call_method` to handle it (we can't do a simpler
   polyfill due to some bug with `var_getattr` raising
   `NotImplementedError`, which forgot to restore symbolic context).
2. it sets and reads attributes (`quant_type`), and
   defines new methods (`as_data`), so this patch adds support for those.
3. it has a `__init__`, which Dynamo needs to trace through in
   `TensorSubclassVariable.call_function`.

Differential Revision: [D71906140](https://our.internmc.facebook.com/intern/diff/D71906140)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149482
Approved by: https://github.com/jansel, https://github.com/mlazos
2025-04-02 17:05:12 +00:00
bobrenjc93
f649ee73ce Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-28 05:36:32 +00:00
PyTorch MergeBot
af7719a2fa Revert "Use source hashing to generate consistent symbolic ids (#149665)"
This reverts commit 1f92348dc6.

Reverted https://github.com/pytorch/pytorch/pull/149665 on behalf of https://github.com/malfet due to Broke trunk, see 6eb3c2e282/1 ([comment](https://github.com/pytorch/pytorch/pull/149665#issuecomment-2758578187))
2025-03-27 16:02:27 +00:00
bobrenjc93
1f92348dc6 Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-27 03:39:27 +00:00
Ryan Guo
ae6158500a [dynamo] fix calling torch function on newly constructed tensor subclass (#149481)
This patch updates existing `test_return_..._subclass` tests in
`test/dynamo/test_subclasses.py`, so that they end up invoking the
`__torch_function__` method of the newly constructed tensor subclass
instnaces.

This exposes a bug in `TensorVariable.method_as_subclass`, where it
forgot to grab the `__func__` out of `__torch_function__`, which led to
the an error down the line.

This patch fixes `TensorVariable.method_as_subclass` by centralizing how
we extract and wrap torch function, in `build_torch_function_fn`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149481
Approved by: https://github.com/jansel
2025-03-24 21:07:41 +00:00
Michael Lazos
1d3c50fcc5 [Dynamo] Support the torch._C.DisableTorchFunction ctx manager (#149491)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149491
Approved by: https://github.com/StrongerXi
ghstack dependencies: #149489, #149490
2025-03-20 22:19:55 +00:00
Animesh Jain
eb9c127341 [dynamo][optimizers] Install ID_GUARDED tensors into the Fx graph (#147824)
Earlier, with inline flag we were lifting id-guarded tensors to the inputs to the Fx graph. But this offers no benefit. Main idea behind lifting parameters as inputs was to reuse the compilation units across many instances of the nn-module. However, if we are guarding on the `id`, we are explicitly specializing the compiled artifact to the parameter.

This PR installs the parameters back into the graph. The benefit is removal of all pre-graph bytecode to extract the id-guarded tensors from locals/globals. This increases speedup from 1.67x to 1.75x for an internal model that has large number of optimizer parameters.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147824
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@meta.com>
2025-02-28 03:22:11 +00:00