Commit Graph

217 Commits

Author SHA1 Message Date
karthickai
10c3e6ec43 [inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)
Fixes #151930

This PR updates the `assert_size_stride` and `assert_alignment` functions in [guards.cpp](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp) to accept an optional `op_name` argument and includes it in the error messages.

The corresponding type stubs in [guards.pyi](https://github.com/pytorch/pytorch/blob/main/torch/_C/_dynamo/guards.pyi) are updated to match the new function arg.

In [inductor/ir.py](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py) extracts the operator name from the FX graph and passes it into the `codegen_size_asserts` and `codegen_alignment_asserts` functions, so that generated assertions in Triton code include the op name for better debugging.

Added unit tests inside [test_torchinductor.py](https://github.com/pytorch/pytorch/blob/main/test/inductor/test_torchinductor.py).
- Verified both successful and failing assertion cases include the operator name.
- Verified that generated Triton code contains the op name inside the asserts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152353
Approved by: https://github.com/jansel, https://github.com/shunting314
2025-06-03 19:21:15 +00:00
Animesh Jain
635b73e697 [dynamo][guards] Flush cache to more accurately measure guard overhead (#154764)
We observed that guard overhead at runtime using profiler traces was
higher than reported in this profiling function at the compile time.
After investigation, we found that f_locals are already in cache and
that was causing the guard overhead to be way smaller while profiling
during the compilation. To be more realistic, we flush the cache here.

Profiling the guard overhead during compilation (in addition to at
runtime) allows faster iteration time, and logging in tlparse and
internal databases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154764
Approved by: https://github.com/zou3519, https://github.com/jansel, https://github.com/StrongerXi
2025-06-03 11:50:57 +00:00
PyTorch MergeBot
b86aaaae0b Revert "[dynamo][guards] Flush cache to more accurately measure guard overhead (#154764)"
This reverts commit 7dee899130.

Reverted https://github.com/pytorch/pytorch/pull/154764 on behalf of https://github.com/seemethere due to This fails internal tests see [fburl.com/diff/67gyp7gp](https://fburl.com/diff/67gyp7gp) ([comment](https://github.com/pytorch/pytorch/pull/154769#issuecomment-2933629894))
2025-06-03 06:13:49 +00:00
Animesh Jain
7dee899130 [dynamo][guards] Flush cache to more accurately measure guard overhead (#154764)
We observed that guard overhead at runtime using profiler traces was
higher than reported in this profiling function at the compile time.
After investigation, we found that f_locals are already in cache and
that was causing the guard overhead to be way smaller while profiling
during the compilation. To be more realistic, we flush the cache here.

Profiling the guard overhead during compilation (in addition to at
runtime) allows faster iteration time, and logging in tlparse and
internal databases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154764
Approved by: https://github.com/zou3519, https://github.com/jansel, https://github.com/StrongerXi
ghstack dependencies: #154769
2025-06-02 23:01:58 +00:00
PyTorch MergeBot
4d073af58c Revert "[inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)"
This reverts commit 725bbb6b5f.

Reverted https://github.com/pytorch/pytorch/pull/152353 on behalf of https://github.com/jeanschmidt due to seems to have broken a few internal tests, @jansel may you help the author get his PR merged? ([comment](https://github.com/pytorch/pytorch/pull/152353#issuecomment-2885997862))
2025-05-16 08:20:39 +00:00
karthickai
725bbb6b5f [inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)
Fixes #151930

This PR updates the `assert_size_stride` and `assert_alignment` functions in [guards.cpp](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp) to accept an optional `op_name` argument and includes it in the error messages.

The corresponding type stubs in [guards.pyi](https://github.com/pytorch/pytorch/blob/main/torch/_C/_dynamo/guards.pyi) are updated to match the new function arg.

In [inductor/ir.py](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py) extracts the operator name from the FX graph and passes it into the `codegen_size_asserts` and `codegen_alignment_asserts` functions, so that generated assertions in Triton code include the op name for better debugging.

Added unit tests inside [test_torchinductor.py](https://github.com/pytorch/pytorch/blob/main/test/inductor/test_torchinductor.py).
- Verified both successful and failing assertion cases include the operator name.
- Verified that generated Triton code contains the op name inside the asserts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152353
Approved by: https://github.com/jansel
2025-05-15 02:33:57 +00:00
PyTorch MergeBot
fdc387ec7c Revert "refine fp32 precision api (#125888)"
This reverts commit 4c11b26158.

Reverted https://github.com/pytorch/pytorch/pull/125888 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to cause some failures on ROCm ([comment](https://github.com/pytorch/pytorch/pull/125888#issuecomment-2869274791))
2025-05-11 00:35:46 +00:00
haozhe.zhu
4c11b26158 refine fp32 precision api (#125888)
Based on the [conversation](https://github.com/pytorch/pytorch/issues/121791), we plan to drop the "highest, high, medium" to represent fp32  internal computation data types . Instead, we will directly use the algorithm to represent it.

### Design Choice: Directly use algorithms name like "TF32", "BF16".
#### Pros
 - The names are more informative. 'tf32' is more informative than a simple "high".
 - Easier to extend new algorithm like `tf32x3`
#### Cons
 - "HIGHEST, HIGH, MEDIUM" indicated the relative precision between different algorithms. However, we can have more documents to discuss them.

### We provide a layered structure for backends/operators.
('f32' is short for 'fp32_precision')
![image](https://github.com/user-attachments/assets/f89143e5-d6a1-4865-9351-9a50439f5067)

### We provide 3 fp32 compute precision can be set:
 - **"ieee"**: Not allowed to use any other internal computation data types .
 - **"tf32"**: Allowed to use tf32 as internal computation data types.
 - **"bf16"**: Allowed to use bf16 as internal computation data types.
 - **"none"**:  Precision's are not set. Can be override by its father node.

### Overriding Precision Settings
Child node can be override by its father node if it is set to default.
For current default settings:
```
backend = generic, op = all, precision setting = none
    backend = cuda, op = all, precision setting = none
        backend = cuda, op = conv, precision setting = tf32
        backend = cuda, op = rnn, precision setting = tf32
        backend = cuda, op = matmul, precision setting = none
    backend = matmul, op = all, precision setting = none
        backend = matmul, op = conv, precision setting = none
        backend = matmul, op = rnn, precision setting = none
        backend = matmul, op = matmul, precision setting = none
```
 - If the user set `torch.backends.mkldnn.fp32_precision="bf16"`, his child nodes `torch.backends.mkldnn.matmul.fp32_precision` / `torch.backends.mkldnn.conv.fp32_precision` / `torch.backends.mkldnn.rnn.fp32_precision` will also be override to "bf16".
 - If the user set `torch.backends.fp32_precision="bf16"`,  `torch.backends.mkldnn.fp32_precision` and his child nodes will also we override to "bf16".

### Backward Compatible
Since new API allow user to have more fine-grained control. There will be some conflict. For example, previous `torch.backends.cudnn.allow_tf32` are not enough to represent the status for `torch.backends.cudnn.rnn.fp32_precision="ieee"` and `torch.backends.cudnn.conv.fp32_precision="tf32"`. Therefore, our goal for backward compatible is
 - If the user only uses previous APIs, it will work as previous expectations.
 - If the user use **new** API to change the status to an **un-representable** status for old API, and try to access the status by **old** API. We will raise Runtime Error and point the document for user.

### Test Plan
```
python test/test_cuda.py -k test_fp32_precision_with_tf32
python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision
python test/test_cuda.py -k test_invalid_status_for_legacy_api
python test/test_mkldnn.py -k test_mlkdnn_get_set
python test/test_mkldnn.py -k test_generic_precision
python test/test_mkldnn.py -k test_invalid
python test/test_mkldnn.py -k test_default_use_parent
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125888
Approved by: https://github.com/jgong5, https://github.com/albanD

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
2025-05-10 11:13:04 +00:00
PyTorch MergeBot
7b806a8cb1 Revert "[inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)"
This reverts commit 9357635127.

Reverted https://github.com/pytorch/pytorch/pull/152353 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to fail an inductor test in trunk ([comment](https://github.com/pytorch/pytorch/pull/152353#issuecomment-2863657185))
2025-05-08 16:39:28 +00:00
karthickai
9357635127 [inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)
Fixes #151930

This PR updates the `assert_size_stride` and `assert_alignment` functions in [guards.cpp](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp) to accept an optional `op_name` argument and includes it in the error messages.

The corresponding type stubs in [guards.pyi](https://github.com/pytorch/pytorch/blob/main/torch/_C/_dynamo/guards.pyi) are updated to match the new function arg.

In [inductor/ir.py](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py) extracts the operator name from the FX graph and passes it into the `codegen_size_asserts` and `codegen_alignment_asserts` functions, so that generated assertions in Triton code include the op name for better debugging.

Added unit tests inside [test_torchinductor.py](https://github.com/pytorch/pytorch/blob/main/test/inductor/test_torchinductor.py).
- Verified both successful and failing assertion cases include the operator name.
- Verified that generated Triton code contains the op name inside the asserts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152353
Approved by: https://github.com/jansel, https://github.com/shunting314
2025-05-08 08:28:05 +00:00
Zhengxu Chen
203201255f [dynamo] remove dead code for DATA_PTR_MATCH (#152206)
Summary: Seems this guard is not created anywhere

Test Plan: CI

Differential Revision: D73682084

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152206
Approved by: https://github.com/anijain2305, https://github.com/jansel
2025-04-26 15:25:01 +00:00
zhxchen17
a34c28e0d2 [dynamo] Add guard serialization for tensor matches. (#151318)
This is a proof-of-concept of how we could serialize a guard and deserialize it back from the bytes.

The main behavioral change introduced in this diff is on CheckFunctionManager:

```
check_fn_manager = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")

guards_state: bytes = check_fn_manager.guards_state
```

Once `guards_serialization_mode` is set to `save`, CheckFunctionManager will return an addtional `bytes` object called `guards_state` which should contain all the information needed for deserializing guards later.

When we load back guards state, we will set `guards_serialization_mode` is set to `load`:

```
output_graph_state = pickle.loads(guards_state)
check_fn_manager = CheckFunctionManager(code, output_graph_state, guards_serialization_mode="load")
```

# TENSOR_MATCH

Since we have many types of guards to support, we will break the work into small diffs instead of a single diff to support every guards.

We kick off the work from TENSOR_MATCH from this diff.

# Testing

For each type of guard we will test it like the following:
1. Use guard_filter_fn to select 1 type of guard each time.
2. Call InstructionTranslator directly on an example function to get OutputGraph and CheckFunctionManager (reference guard manager)
3. Serialize->deserialize the output graph state and re-build the guards with a new CheckFunctionManager (loaded guard manager)
4. Throw a set of example inputs to both reference and loaded guard manager to see if their behavior match.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151318
Approved by: https://github.com/jansel, https://github.com/anijain2305
2025-04-25 14:16:23 +00:00
PyTorch MergeBot
b1d055fd6a Revert "[dynamo] Add guard serialization for tensor matches. (#151318)"
This reverts commit 81c4369d81.

Reverted https://github.com/pytorch/pytorch/pull/151318 on behalf of https://github.com/zhxchen17 due to macos test failing ([comment](https://github.com/pytorch/pytorch/pull/151318#issuecomment-2828638168))
2025-04-24 19:22:45 +00:00
zhxchen17
81c4369d81 [dynamo] Add guard serialization for tensor matches. (#151318)
This is a proof-of-concept of how we could serialize a guard and deserialize it back from the bytes.

The main behavioral change introduced in this diff is on CheckFunctionManager:

```
check_fn_manager = CheckFunctionManager(code, output_graph, guards_serialization_mode="save")

guards_state: bytes = check_fn_manager.guards_state
```

Once `guards_serialization_mode` is set to `save`, CheckFunctionManager will return an addtional `bytes` object called `guards_state` which should contain all the information needed for deserializing guards later.

When we load back guards state, we will set `guards_serialization_mode` is set to `load`:

```
output_graph_state = pickle.loads(guards_state)
check_fn_manager = CheckFunctionManager(code, output_graph_state, guards_serialization_mode="load")
```

# TENSOR_MATCH

Since we have many types of guards to support, we will break the work into small diffs instead of a single diff to support every guards.

We kick off the work from TENSOR_MATCH from this diff.

# Testing

For each type of guard we will test it like the following:
1. Use guard_filter_fn to select 1 type of guard each time.
2. Call InstructionTranslator directly on an example function to get OutputGraph and CheckFunctionManager (reference guard manager)
3. Serialize->deserialize the output graph state and re-build the guards with a new CheckFunctionManager (loaded guard manager)
4. Throw a set of example inputs to both reference and loaded guard manager to see if their behavior match.

Differential Revision: [D72987485](https://our.internmc.facebook.com/intern/diff/D72987485/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151318
Approved by: https://github.com/jansel, https://github.com/anijain2305
2025-04-24 18:07:01 +00:00
Animesh Jain
7b1a2373e8 [dynamo][super variable] Fix bug to use correct source (#151154)
Fixes https://github.com/pytorch/pytorch/issues/150994

We should cherry-pick to 2.7 branch if possible, because this breaks torch.compile on some HF models. Look at the issue referenced here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151154
Approved by: https://github.com/jansel
2025-04-13 04:48:52 +00:00
Bartlomiej Stemborowski
12281f9c18 [dynamo] Deprecate enable_cpp_framelocals_guard_eval config variable - default: True (#151008)
[dynamo] Deprecate enable_cpp_framelocals_guard_eval config variable - default: True

Reading the feature enabling param `enable_cpp_framelocals_guard_eval `at the CPP level is time consuming and slows down the operation of the dynamo as it is done every time the function using this param is called. Reading the value only once at init isn’t an option as it would disable the modification of this param at the runtime. Since this feature is enabled by default for some time and it doesn’t cause known issues, the `enable_cpp_framelocals_guard_eval `configuration param will be deprecated by this commit and its value is hardcoded to true.

Local microbenchmark dynamo_guard_eval.py:
- 931.9 us -> 538.9 us (3.10)

@williamwen42 @jansel @anijain2305

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151008
Approved by: https://github.com/williamwen42
2025-04-11 21:07:59 +00:00
Isuru Fernando
a22d3e778e [dynamo][guards] Print relational guards only once (#150810)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150810
Approved by: https://github.com/anijain2305
2025-04-11 04:10:37 +00:00
Shunting Zhang
252029b294 [Inductor] assert fallback output alignment (#150804)
Previous PR (https://github.com/pytorch/pytorch/pull/150777) fixes the alignment problem for fallback kernel assuming meta kernel is correct. This PR handles the case that meta kernel is incorrect. Assertion is added if the compiler assumes a  fallback kernel output is aligned.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150804
Approved by: https://github.com/jansel, https://github.com/eellison
ghstack dependencies: #150777
2025-04-10 20:01:06 +00:00
Zhengxu Chen
24aadb40fb [precompile] Serialization for GlobalStateGuard (#150636)
Summary: To preserve global state guards we need to make the C++ type serialzable. Using json because it's easier to do and we don't have a lot of data in global state.

Test Plan: test_dynamo -k test_global_state_guard_serialization

Differential Revision: D72410611

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150636
Approved by: https://github.com/williamwen42
2025-04-07 03:10:03 +00:00
Isuru Fernando
ff783f062a Fix shape guard failure to be valid python (#149149)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149149
Approved by: https://github.com/anijain2305
2025-04-03 14:36:17 +00:00
cyy
e872c38eb3 Remove cppcoreguidelines-pro-type-member-init_fix suppression (#148638)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148638
Approved by: https://github.com/zou3519
2025-04-02 01:33:20 +00:00
William Wen
a66a9581da [dynamo] support Python 3.13t (#149549)
A few bug fixes to get Dynamo mostly working with 3.13 nogil. Dynamo encounters internal CPython assert errors in older versions of 3.13. The fix has been landed on [CPython's 3.13 branch](https://github.com/python/cpython/tree/3.13) and will be included in 3.13.3 (https://peps.python.org/pep-0719/ - april 8). If you wish to try `torch.compile` on the latest 3.13 branch, you can comment out the error checking (i.e. 70b6cd4e11/torch/__init__.py (L2535) and 70b6cd4e11/torch/_dynamo/eval_frame.py (L899)).

We will work on getting PyTorch CI up for Dynamo/dynamo-wrapped/inductor once 3.13.3 is available.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149549
Approved by: https://github.com/jansel
2025-03-20 09:49:27 +00:00
Animesh Jain
f9a787224c [dynamo][guards][serialization] Dont use ID_MATCH guard for bool and None (#149228)
Doing this removes the need of collecting `id` and therefore facilitates serialization. It also improves readability with recompilations. Earlier, recompile message will just show the `id`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149228
Approved by: https://github.com/jansel
2025-03-18 01:25:37 +00:00
PyTorch MergeBot
b52a8bef01 Revert "[dynamo][guards][serialization] Dont use ID_MATCH guard for bool and None (#149228)"
This reverts commit 5905bbe745.

Reverted https://github.com/pytorch/pytorch/pull/149228 on behalf of https://github.com/malfet due to I wonder if this will fix the pr-time-benchmark regressions ([comment](https://github.com/pytorch/pytorch/pull/149228#issuecomment-2731237949))
2025-03-18 00:10:50 +00:00
Animesh Jain
5905bbe745 [dynamo][guards][serialization] Dont use ID_MATCH guard for bool and None (#149228)
Doing this removes the need of collecting `id` and therefore facilitates serialization. It also improves readability with recompilations. Earlier, recompile message will just show the `id`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149228
Approved by: https://github.com/jansel
2025-03-16 15:56:17 +00:00
Animesh Jain
992838e702 [dynamo][guards] Do not ID_MATCH on numpy tensors (#148923)
Might help with https://github.com/pytorch/pytorch/issues/148535

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148923
Approved by: https://github.com/jansel
2025-03-11 14:20:26 +00:00
cyy
f7c0c230b0 Fix compile errors (#148758)
Fix
```
  /usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/bits/unique_ptr.h:91:16: error: invalid application of 'sizeof' to an incomplete type 'torch::jit::AliasDb::WriteRegistry'
     91 |         static_assert(sizeof(_Tp)>0,
        |                       ^~~~~~~~~~~
  /usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/bits/unique_ptr.h:399:4: note: in instantiation of member function 'std::default_delete<torch::jit::AliasDb::WriteRegistry>::operator()' requested here
    399 |           get_deleter()(std::move(__ptr));
        |           ^
  ../torch/csrc/jit/ir/alias_analysis.cpp:200:10: note: in instantiation of member function 'std::unique_ptr<torch::jit::AliasDb::WriteRegistry>::~unique_ptr' requested here
    200 | AliasDb::~AliasDb() = default;
        |          ^
  ../torch/csrc/jit/ir/alias_analysis.cpp:200:23: note: in defaulted destructor for 'torch::jit::AliasDb' first required here
    200 | AliasDb::~AliasDb() = default;
        |                       ^
  ../torch/csrc/jit/ir/alias_analysis.h:298:10: note: forward declaration of 'torch::jit::AliasDb::WriteRegistry'
    298 |   struct WriteRegistry;
        |          ^
  1 error generated.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148758
Approved by: https://github.com/Skylion007
2025-03-08 04:56:42 +00:00
Animesh Jain
713a504a82 [dynamo][guards] Fix mem leak caused be refcount increment (#148480)
Should help [internalfb.com/sevmanager/view/491701](https://www.internalfb.com/sevmanager/view/491701)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148480
Approved by: https://github.com/xmfan, https://github.com/StrongerXi, https://github.com/williamwen42, https://github.com/zou3519
2025-03-05 01:04:08 +00:00
Animesh Jain
276dfe8150 [dynamo][cpp-guards] Disable dict-tag optim if the guard_manager has child accessors (#147694)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147694
Approved by: https://github.com/isuruf
2025-02-26 00:02:08 +00:00
Animesh Jain
9dc702875d [dynamo][mappingproxy][inspect] Support existing types.MappingProxyType (#147217)
Fixes https://github.com/pytorch/pytorch/issues/147162

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147217
Approved by: https://github.com/williamwen42
2025-02-15 07:59:33 +00:00
Shunting Zhang
bc0191802f [inductor] add size-asserts for fallback ops (#145904)
Fix https://github.com/pytorch/pytorch/issues/144717

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145904
Approved by: https://github.com/jansel
2025-02-07 18:44:32 +00:00
Isuru Fernando
08b14936ae Disable has_relational_guards check for dict_tag optimization for now (#146232)
has_relational_guards evaluates to true almost always, and leads to a
slowdown in guards runtime

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146232
Approved by: https://github.com/anijain2305
2025-02-03 07:56:06 +00:00
wengshiy
73622fc5fa Fix Throughputbenchmark issue (#144669)
Fixes [144461](https://github.com/pytorch/pytorch/issues/144461)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144669
Approved by: https://github.com/leslie-fang-intel, https://github.com/williamwen42, https://github.com/jansel
2025-01-26 03:37:20 +00:00
Animesh Jain
015c6d6fdb [dynamo][guards] Turn on profiling of guard manager (#145420)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145420
Approved by: https://github.com/ezyang
ghstack dependencies: #145351
2025-01-23 18:17:43 +00:00
Isuru Fernando
0efa843392 Dynamic shape guards in C++ (#139899)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139899
Approved by: https://github.com/anijain2305, https://github.com/albanD, https://github.com/jansel
ghstack dependencies: #143385, #143164
2025-01-22 14:58:35 +00:00
Yanbo Liang
5d02575aa1 [Trace Python dispatcher] Support torch.DispatchKey & torch.DispatchKeySet (#144439)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144439
Approved by: https://github.com/zou3519
2025-01-17 02:26:36 +00:00
cyy
2ea394ba29 Modernize C++ code (#144603)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144603
Approved by: https://github.com/malfet
2025-01-17 00:25:18 +00:00
Animesh Jain
732359c633 [dynamo][easy] Minor fixes in guards.cpp (#144130)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144130
Approved by: https://github.com/williamwen42
ghstack dependencies: #144129
2025-01-03 18:22:56 +00:00
cyy
8df99b6a6e Remove unneeded std::make_optional (#143575)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143575
Approved by: https://github.com/Skylion007
2024-12-31 03:08:47 +00:00
cyy
dca443835e Enable more readability-redundant checks (#143963)
They are helpful to simplifying code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143963
Approved by: https://github.com/albanD
2024-12-30 14:49:33 +00:00
Animesh Jain
e296bab614 [dynamo] Remove DICT_SUBCLASS_GUARD_MANAGER and use dict.keys (#143722)
In hinsight, we never needed a DICT_SUBCLASS_GUARD_MANAGER, because Dynamo would inline through the overridden keys method. In this PR, we ensure that while creating guards and constructing variable trackers, we get the `d.keys()` value by using `dict.keys(d)`. This ensures that we do not call overridden keys method. Therefore, the C++ guard can use `PyDict_Next` directly to check the guards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143722
Approved by: https://github.com/jansel
2024-12-27 04:51:35 +00:00
PyTorch MergeBot
26364428f5 Revert "[dynamo] Remove DICT_SUBCLASS_GUARD_MANAGER and use dict.keys (#143722)"
This reverts commit fe95cbe018.

Reverted https://github.com/pytorch/pytorch/pull/143722 on behalf of https://github.com/wdvr due to failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/143722#issuecomment-2563127017))
2024-12-26 22:04:36 +00:00
Jason Ansel
b0c3f48a40 [inductor] Improve error message for assert_size_stride (#143765)
```
>>> torch._C._dynamo.guards.assert_size_stride(torch.randn(10), (10,), (2,))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AssertionError: expected size 10==10, stride 1==2 at dim=0
This error most often comes from an incorrect meta function for a custom op.
See https://pytorch.org/docs/stable/library.html#torch.library.opcheck
>>>
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143765
Approved by: https://github.com/zou3519
2024-12-24 05:26:05 +00:00
cyy
1feae27ed6 [16/N] Fix extra warnings brought by clang-tidy-17 (#143714)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143714
Approved by: https://github.com/Skylion007, https://github.com/albanD
2024-12-24 03:29:38 +00:00
Animesh Jain
fe95cbe018 [dynamo] Remove DICT_SUBCLASS_GUARD_MANAGER and use dict.keys (#143722)
In hinsight, we never needed a DICT_SUBCLASS_GUARD_MANAGER, because Dynamo would inline through the overridden keys method. In this PR, we ensure that while creating guards and constructing variable trackers, we get the `d.keys()` value by using `dict.keys(d)`. This ensures that we do not call overridden keys method. Therefore, the C++ guard can use `PyDict_Next` directly to check the guards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143722
Approved by: https://github.com/jansel
2024-12-24 02:00:18 +00:00
William Wen
7ab880bc5e fix typo in autocast header (#143625)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143625
Approved by: https://github.com/mlazos
ghstack dependencies: #143592
2024-12-20 16:17:15 +00:00
William Wen
1c2593f035 [dynamo] guard global autocast state (#143592)
Fixes https://github.com/pytorch/pytorch/issues/112260.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143592
Approved by: https://github.com/jansel
2024-12-20 03:30:54 +00:00
Animesh Jain
465f282a24 [reland][dynamo][guards] Consider tensors as immutable for dict tag matches (#141085)
Reland - https://github.com/pytorch/pytorch/pull/139560

As mentioned in https://github.com/pytorch/pytorch/pull/130341, using `static py::object` can lead to segfaults. I suspect this is the reason for the import system error seen internally (https://www.internalfb.com/sevmanager/view/469592). In this PR, I am removing the `static` part. This is fine and also the right thing to do because this will catch if user changes the flag in the same process for compiling two different functions.

Unfortunately, there is no easy way to trigger this segfault, so I can't write a test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141085
Approved by: https://github.com/jansel

Co-authored-by: William Wen <williamwen@meta.com>
2024-12-19 15:16:10 +00:00
William Wen
18261e9f39 [dynamo] implement framelocals mapping as c++ object (#140063)
Implements https://github.com/pytorch/pytorch/issues/93753 - move frame local guard accessors to C++.

Before, we used dict accessors on a Python dict representing the frame's fastlocals that we manually build. We move this accessor to C++ and additionally use the fastlocal index whenever possible.

Some implementation notes:
- `FrameLocalsMapping` is now initialized as a C++ vector of `PyObject`s. We do not just use the frame's localsplus/fastlocals buffer because we also unbox cells.
- `FrameLocalsMapping` can still be converted into a Python dict representing the frame's fastlocals, but it is done lazily.
- We update `LeafGuard`, `GuardAccessor`, and `GuardManager`'s `check_nopybind` methods to accept `FrameLocalsMapping`. By default, we convert the `FrameLocalsMapping` to a Python dict and run the original `check_nopybind` on it, but in some cases, conversion is not needed.
- We add a new guard accessor `FrameLocalsGuardAccessor`, which is similar to `DictGetItemGuardAccessor` but has special handling for `FrameLocalsMapping`. We create a separate class to emphasize different use cases, but we could probably combine these two (can do in a follow up)

dynamo_guard_eval.py microbenchmark update:
- 713.2us -> 630.0us (3.10)
- 598.8us -> 530.7us (3.12)

Other followups:
- Add `FrameLocalsMapping` version for `check_verbose_nopybind` in order to match behavior between `check_nopybind` and `check_verbose_nopybind`. This can prevent difficult debugging situations where guards fail (`check_nopybind` returns false) but no guard error message is generated (`check_verbose_nopybind` succeeds).
- Rewrite the `SHAPE_ENV` guard into C++ - it is a fairly common guard that results in `FrameLocalsMapping` needing to convert to a dict

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140063
Approved by: https://github.com/jansel
ghstack dependencies: #142117, #142430
2024-12-17 18:54:27 +00:00
PyTorch MergeBot
e3d754419f Revert "[reland][dynamo][guards] Consider tensors as immutable for dict tag matches (#141085)"
This reverts commit 1bf983077f.

Reverted https://github.com/pytorch/pytorch/pull/141085 on behalf of https://github.com/huydhn due to The diff D66211131 has been commandeered internally and is it not part of the train anymore.  If codev is needed, pls reland this accordingly ([comment](https://github.com/pytorch/pytorch/pull/141085#issuecomment-2549092225))
2024-12-17 17:21:14 +00:00