Commit Graph

704 Commits

Author SHA1 Message Date
Zhengxu Chen
24aadb40fb [precompile] Serialization for GlobalStateGuard (#150636)
Summary: To preserve global state guards we need to make the C++ type serialzable. Using json because it's easier to do and we don't have a lot of data in global state.

Test Plan: test_dynamo -k test_global_state_guard_serialization

Differential Revision: D72410611

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150636
Approved by: https://github.com/williamwen42
2025-04-07 03:10:03 +00:00
Pian Pawakapan
c6d79c163c [dynamic shapes] allow duck typing for 0/1 (#150222)
Fixes #150184

e.g. for config.backed_size_oblivious=True and compile

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150222
Approved by: https://github.com/laithsakka
2025-04-04 03:24:46 +00:00
Aby Mathew C
7df6f930e8 Adapt test_misc.py for HPUs (#149499)
This PR is related to https://github.com/pytorch/pytorch/pull/145476 . That PR had two files (test_functions.py and test_misc.py) . test_functions was causing CI/rebase/merge issues and hence removed for now. This PR contains only test_misc.py.

This is a continuation of https://github.com/pytorch/pytorch/pull/144387 .

## MOTIVATION
We recently integrated support for Intel Gaudi devices (identified as 'hpu') into the common_device_type framework via the pull request at https://github.com/pytorch/pytorch/pull/126970. This integration allows tests to be automatically instantiated for Gaudi devices upon loading the relevant library. Building on this development, the current pull request extends the utility of these hooks by adapting selected CUDA tests to operate on Gaudi devices. Additionally, we have confirmed that these modifications do not interfere with the existing tests on CUDA devices.

Other accelerators can also extend the functionality by adding the device in the devices list. ( For eg: xpu )

## CHANGES
Create a separate class for test functions running on CUDA devices
Extend the functionality of these tests to include HPUs
Use instantiate_device_type_tests with targeted attributes to generate device-specific test instances within the new classes
Apply skipIfHPU decorator to bypass tests that are not yet compatible with HPU devices

PS: Most of these changes were initially part of https://github.com/pytorch/pytorch/pull/147609 , but closed that PR due to merge conflicts. The review comments were handled in this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149499
Approved by: https://github.com/EikanWang, https://github.com/desertfire, https://github.com/cyyever
2025-04-04 02:47:43 +00:00
Ryan Guo
33535b3eee [dynamo] Support Tensor subclass that has dynamic attributes or calls Parameter.__torch_function__ (#149482)
This fixes most of https://github.com/huggingface/diffusers/issues/10795,
except for `torch.Tensor._make_subclass`, which will be fixed in a
subsequent patch.

The relevant tensor subclass from the aforementioned issue is defined
here: fbf6b856cc/src/diffusers/quantizers/gguf/utils.py (L398-L435).

There are two things to note about the tensor subclass:
1. it calls `super().__torch_function__`, which is
   `torch._C._disabled_torch_function_impl`, so this patch updates
   `SuperVariable.call_method` to handle it (we can't do a simpler
   polyfill due to some bug with `var_getattr` raising
   `NotImplementedError`, which forgot to restore symbolic context).
2. it sets and reads attributes (`quant_type`), and
   defines new methods (`as_data`), so this patch adds support for those.
3. it has a `__init__`, which Dynamo needs to trace through in
   `TensorSubclassVariable.call_function`.

Differential Revision: [D71906140](https://our.internmc.facebook.com/intern/diff/D71906140)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149482
Approved by: https://github.com/jansel, https://github.com/mlazos
2025-04-02 20:56:43 +00:00
PyTorch MergeBot
03c879d59b Revert "[dynamo] Support Tensor subclass that has dynamic attributes or calls Parameter.__torch_function__ (#149482)"
This reverts commit 98453c135a.

Reverted https://github.com/pytorch/pytorch/pull/149482 on behalf of https://github.com/malfet due to Broke trunk, see b03c42109c/1 ([comment](https://github.com/pytorch/pytorch/pull/149482#issuecomment-2773650522))
2025-04-02 20:30:33 +00:00
Ryan Guo
98453c135a [dynamo] Support Tensor subclass that has dynamic attributes or calls Parameter.__torch_function__ (#149482)
This fixes most of https://github.com/huggingface/diffusers/issues/10795,
except for `torch.Tensor._make_subclass`, which will be fixed in a
subsequent patch.

The relevant tensor subclass from the aforementioned issue is defined
here: fbf6b856cc/src/diffusers/quantizers/gguf/utils.py (L398-L435).

There are two things to note about the tensor subclass:
1. it calls `super().__torch_function__`, which is
   `torch._C._disabled_torch_function_impl`, so this patch updates
   `SuperVariable.call_method` to handle it (we can't do a simpler
   polyfill due to some bug with `var_getattr` raising
   `NotImplementedError`, which forgot to restore symbolic context).
2. it sets and reads attributes (`quant_type`), and
   defines new methods (`as_data`), so this patch adds support for those.
3. it has a `__init__`, which Dynamo needs to trace through in
   `TensorSubclassVariable.call_function`.

Differential Revision: [D71906140](https://our.internmc.facebook.com/intern/diff/D71906140)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149482
Approved by: https://github.com/jansel, https://github.com/mlazos
2025-04-02 17:05:12 +00:00
bobrenjc93
f649ee73ce Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-28 05:36:32 +00:00
PyTorch MergeBot
af7719a2fa Revert "Use source hashing to generate consistent symbolic ids (#149665)"
This reverts commit 1f92348dc6.

Reverted https://github.com/pytorch/pytorch/pull/149665 on behalf of https://github.com/malfet due to Broke trunk, see 6eb3c2e282/1 ([comment](https://github.com/pytorch/pytorch/pull/149665#issuecomment-2758578187))
2025-03-27 16:02:27 +00:00
bobrenjc93
1f92348dc6 Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache missing due to PGO. At a high level the problem looks as follows

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do dynamic compile, get a dynamic graph and save to PGO. Crucially what we save to PGO is actually a superset of what is actually dynamic. If we notice an input was varying, we mark it as dynamic in PGO even if later on that value gets specialized. When a value gets specialized, we actually remove the symbol from the graph. This results in an interesting conundrum where although we are producing the same isomorphic graph, PGO makes the second run cache miss. Let's see how....

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get a fx graph, look it up in the cache and... whoops! cache miss! This is because of the aforementioned behavior where the PGO profile will cause us to over-allocate symbols. In practice this means we end up saving a graph in cache with symbols x:s1, y:s3 and on second attempt we cache miss with x:s1, y:s6 where symbols s3,s4,s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names. This ensures somewhat stable assignment. To prevent catastrophic symbol collisions, we use linear probing to ensure no collisions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-27 03:39:27 +00:00
Ryan Guo
1c98dc3664 [dynamo] Fix handling of setattr with some tensor attributes (#149791)
We weren't handling `setattr(tensor_obj, "real", 42)` correctly, because
the attribute is a `GetSetDescriptorType` that has special setter logic.
See added test and comments for more explanations.

This patch makes it so that we graph break in those cases, rather than
resulting in silent incorrectness.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149791
Approved by: https://github.com/mlazos
ghstack dependencies: #149481
2025-03-25 18:57:56 +00:00
bobrenjc93
621c801f78 fix dynamic float when dynamic=True (#149564)
Fixes https://github.com/pytorch/pytorch/issues/149406#issuecomment-2738111733. Basically previously we would only make floats dynamic via automatic dynamic, now if you set dynamic=True, we will make the floats dynamic on the first compile.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149564
Approved by: https://github.com/laithsakka
2025-03-22 05:58:59 +00:00
Zhengxu Chen
f47aa08130 [export] Support python assertion with symints. (#149444)
Summary: This diff ports some technique from torch.fx symbolic trace to trace through Python asserts when we run into data dependent symbolic shape assertions, so that we can achieve the same effect as torch dynamo to automatically turn assert into torch.check()s.

Test Plan: buck test mode/opt caffe2/test:test_export -- -r test_python_asserts_with_sym_int
Differential Revision: D71425360

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149444
Approved by: https://github.com/tugsbayasgalan
2025-03-20 23:07:45 +00:00
Guilherme Leobas
44e6464914 Allow setting attribute to NestedUserFunctionVariable (#146505)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146505
Approved by: https://github.com/zou3519
2025-03-20 19:59:30 +00:00
George Wigley
96a6a71ac7 skip test_torch_dynamo_codegen_pow if CPU backend is not cpp (#146595)
The test asserts that `aten.pow` is not present in the generated kernel code. When using a CPU backend other than cpp, the kernel contains comments referencing the aten ops that produced the kernel in this case `aten.pow`.

This PR skips that test case if the CPU backend is not cpp.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146595
Approved by: https://github.com/williamwen42
2025-03-13 10:03:29 +00:00
Animesh Jain
f1787ee0f7 [dynamo] Remove L scoping for recompilation messages (#148917)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148917
Approved by: https://github.com/williamwen42
2025-03-11 14:26:26 +00:00
bobrenjc93
4708cfdbd9 Support whitelist of dynamic sources (#147979)
This PR introduces the ability to whitelist sources as dynamic. This is particularly useful for large models with graph breaks, as you can keep the dynamism across graph breaks since source names stay consistent. Additionally you can use this to mark ints as dynamic.

NB: I intentionally didn't complicate the interface by supporting specification of per dimension dynamism. There is virtue in keeping true to the standard way of representing sources (eg. L['x']). If we find in practice that we need more more fine grained control, we can explore further affordances at that time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147979
Approved by: https://github.com/Mingming-Ding
2025-02-28 15:43:14 +00:00
bobrenjc93
0d56b7e665 Support size oblivious max equation (#147344)
Addresses https://github.com/pytorch/pytorch/issues/125914 by detecting when we have a sym_max between {0, 1} and a summation of size-like unbacked symints.

The basic idea is max(1, u0 + u1) can be simplified to u0 + u1 if both u0 and u1 are size-like since their value ranges are [2, inf].

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147344
Approved by: https://github.com/angelayi
2025-02-20 04:33:19 +00:00
William Wen
16e202a38e [dynamo] improved graph break messages for some common graph break sites [1/N] (#146525)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146525
Approved by: https://github.com/jansel
2025-02-20 00:08:13 +00:00
bobrenjc93
525ca80f53 add unbacked strict mode (#147333)
fixes #145775

This is the first step in introducing a "strict" mode where we don't silent specialize and don't silent graph break. At a high level when we do mark_unbacked(... strict=True), anytime we specialize an unbacked symint we will explicitly error and tell the user their unbacked dimension was specialized to a single value.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147333
Approved by: https://github.com/laithsakka
2025-02-18 23:33:55 +00:00
bobrenjc93
5d547d82e6 Add no_data_dependent_graph_break mode (#147342)
This adds a strict mode `TORCHDYNAMO_UNBACKED_STRICT` to prevent graph breaking when we guard on data dependent. This is a better UX for those who are actively trying to make their model more dynamic, but aren't close enough to full graph to use that flag directly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147342
Approved by: https://github.com/laithsakka
2025-02-18 23:33:47 +00:00
Guilherme Leobas
cefd9805de Add RAISE_VARARGS 0 (#146493)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146493
Approved by: https://github.com/zou3519
ghstack dependencies: #146498, #146492
2025-02-14 13:37:23 +00:00
Guilherme Leobas
6a9a02acbe Set enable_faithful_generator_behavior flag to True (#142513)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142513
Approved by: https://github.com/zou3519
ghstack dependencies: #141055, #144421, #144422, #144423, #144424, #144420, #145223
2025-02-08 22:42:12 +00:00
Animesh Jain
e2e265e27b [dynamo] Use polyfill to implement comparison operators (#144485)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144485
Approved by: https://github.com/jansel
2025-02-06 17:27:07 +00:00
Harmen Stoppels
01554c7b5a fix incorrect literal strings / accidental tuples (#146037)
* `expr,` is short for `(expr,)`
* literal strings over multiple lines need to escape the newline `\` or use `(...)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146037
Approved by: https://github.com/Skylion007
2025-02-03 15:08:11 +00:00
Yanbo Liang
511d0dd558 [Dynamo][Trace PyDispatcher] Support calling id function over class (#146269)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146269
Approved by: https://github.com/anijain2305
2025-02-02 22:29:30 +00:00
Animesh Jain
cef856faa9 [dynamo][enum] Trace through enum.py for enum construction (#146070)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146070
Approved by: https://github.com/jansel
ghstack dependencies: #146062, #146198, #146258, #146214
2025-02-02 03:12:36 +00:00
Animesh Jain
31fb691782 [dynamo] Graph break on tensor.retain_grad (#146214)
Fixes https://github.com/pytorch/pytorch/issues/146212

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146214
Approved by: https://github.com/jansel
ghstack dependencies: #146062, #146198, #146258
2025-02-02 03:12:36 +00:00
bobrenjc93
30f091da44 add speculation log divergence test (#145659)
Followup from a SEV. Confirmed that this breaks when stacked on top of https://github.com/pytorch/pytorch/pull/145660 (offending PR that caused the SEV)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145659
Approved by: https://github.com/laithsakka
2025-02-01 09:39:22 +00:00
cyy
18380836eb Remove outdated test skipif conditions for Python3.9 (#146144)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146144
Approved by: https://github.com/albanD
2025-01-31 19:01:04 +00:00
rzou
2e5886dcc4 Add fake_impl for unique_consecutive (#145649)
Summary:
It's fairly similar to torch.unique and torch.unique_dim.

Test Plan:
New test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145649
Approved by: https://github.com/ezyang, https://github.com/eellison
2025-01-29 22:33:16 +00:00
PyTorch MergeBot
1185b81c51 Revert "[dynamo] Use polyfill to implement comparison operators (#144485)"
This reverts commit d1f82de2bf.

Reverted https://github.com/pytorch/pytorch/pull/144485 on behalf of https://github.com/huydhn due to This seems to break dynamo tests in trunk after landing ([comment](https://github.com/pytorch/pytorch/pull/144485#issuecomment-2622893294))
2025-01-29 21:30:42 +00:00
Animesh Jain
d1f82de2bf [dynamo] Use polyfill to implement comparison operators (#144485)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144485
Approved by: https://github.com/jansel
2025-01-29 17:37:40 +00:00
bobrenjc93
8696e59ae2 add test for capture_dynamic_output_shape_ops=True changing expected output between eager and compiled versions (#145821)
Followup from https://github.com/pytorch/pytorch/issues/130290

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145821
Approved by: https://github.com/eellison, https://github.com/ezyang
2025-01-29 04:36:32 +00:00
Ryan Guo
eaec97ab1f [dynamo] Properly prune dead input cell object (#145781)
This patch models input cell object as "newly created" rather than
"pre-existing" python object (see added documentation for why this
actually captures the semantics more accurately).

This enables the `SideEffects.prune_dead_object_new` algorithm to prune
away writes to input cell objects which are no longer relevant; this
didn't happen prior to this patch because we modelled them as
pre-existing objects, which forces us to codegen their attribute
mutations.

Fixes #145564.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145781
Approved by: https://github.com/williamwen42, https://github.com/jansel
2025-01-28 18:28:13 +00:00
bobrenjc93
6f07847efe Bail on checking internal overlap when dealing with unbacked symints (#145385)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145385
Approved by: https://github.com/ezyang
2025-01-23 22:31:31 +00:00
Aaron Orenstein
1ce533867f Teach dynamo to handle GenericAlias without a graph break (#145240)
Dynamo wasn't handling the new PEP585 type annotations:
```
x = list[Foo]
```
Although this worked in py3.9 this was causing an `unimplemented` (Unexpected type in sourceless builder) in py3.12.

This fixes it to treat them as a BuiltinVariable.

Fixes #145226

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145240
Approved by: https://github.com/anijain2305
2025-01-22 01:55:51 +00:00
Animesh Jain
8ccf3f6f3f [dynamo][easy] Move dict tests to test_dicts.py (#144165)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144165
Approved by: https://github.com/jansel
ghstack dependencies: #143997
2025-01-08 03:56:33 +00:00
Animesh Jain
2ac41404a8 [dynamo][dicts] Guarding lazily on dict keys (#143997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143997
Approved by: https://github.com/jansel
2025-01-08 03:56:33 +00:00
Xiaodong Wang
3d3a07963f [reland][attempt2][AMD] Turn on TF32 for aten::mm (#144145)
Summary:
https://github.com/pytorch/pytorch/pull/143549 was reverted due to some
internal/oss tooling issue. Relanding.

hipblaslt supports TF32, so adding the support.
Original PR https://github.com/pytorch/pytorch/pull/139869

Test Plan: CI

Differential Revision: D67785496

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144145
Approved by: https://github.com/jianyuh
2025-01-06 00:37:01 +00:00
Animesh Jain
f6488d85a0 [dynamo][user-defined] Remove __getattribute__ checks and add getsetdescriptor (#144173)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144173
Approved by: https://github.com/jansel
2025-01-05 13:48:15 +00:00
PyTorch MergeBot
b01556bd8a Revert "[dynamo][dicts] Guarding lazily on dict keys (#143997)"
This reverts commit f5df082fab.

Reverted https://github.com/pytorch/pytorch/pull/143997 on behalf of https://github.com/jeanschmidt due to Seems to have introduced internal ci redness in some tests, D67828366 ([comment](https://github.com/pytorch/pytorch/pull/143997#issuecomment-2571587599))
2025-01-05 11:09:45 +00:00
Animesh Jain
f5df082fab [dynamo][dicts] Guarding lazily on dict keys (#143997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143997
Approved by: https://github.com/jansel
ghstack dependencies: #144129, #144130, #144141, #144158, #144163, #144160
2025-01-04 18:13:00 +00:00
Animesh Jain
c5c897c3a1 [dynamo][easy] Miscellaneous fixes (#144141)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144141
Approved by: https://github.com/williamwen42
ghstack dependencies: #144129, #144130
2025-01-03 18:22:56 +00:00
Animesh Jain
a87cd5283b [dynamo] Trace through overridden __getattribute__ method (#143888)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143888
Approved by: https://github.com/jansel
2024-12-27 18:10:00 +00:00
Animesh Jain
e296bab614 [dynamo] Remove DICT_SUBCLASS_GUARD_MANAGER and use dict.keys (#143722)
In hinsight, we never needed a DICT_SUBCLASS_GUARD_MANAGER, because Dynamo would inline through the overridden keys method. In this PR, we ensure that while creating guards and constructing variable trackers, we get the `d.keys()` value by using `dict.keys(d)`. This ensures that we do not call overridden keys method. Therefore, the C++ guard can use `PyDict_Next` directly to check the guards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143722
Approved by: https://github.com/jansel
2024-12-27 04:51:35 +00:00
PyTorch MergeBot
26364428f5 Revert "[dynamo] Remove DICT_SUBCLASS_GUARD_MANAGER and use dict.keys (#143722)"
This reverts commit fe95cbe018.

Reverted https://github.com/pytorch/pytorch/pull/143722 on behalf of https://github.com/wdvr due to failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/143722#issuecomment-2563127017))
2024-12-26 22:04:36 +00:00
Animesh Jain
fe95cbe018 [dynamo] Remove DICT_SUBCLASS_GUARD_MANAGER and use dict.keys (#143722)
In hinsight, we never needed a DICT_SUBCLASS_GUARD_MANAGER, because Dynamo would inline through the overridden keys method. In this PR, we ensure that while creating guards and constructing variable trackers, we get the `d.keys()` value by using `dict.keys(d)`. This ensures that we do not call overridden keys method. Therefore, the C++ guard can use `PyDict_Next` directly to check the guards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143722
Approved by: https://github.com/jansel
2024-12-24 02:00:18 +00:00
Oguz Ulgen
dc55704b48 Rename cache limit to recompile limit in configs (#143709)
This PR renames every cache_limit to recompile_limit via sed.

Old config options are maintained via Config(alias='xyz')

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143709
Approved by: https://github.com/jansel
2024-12-22 10:03:57 +00:00
Animesh Jain
4627cfd1f9 [dynamo] Support user defined dicts (#143548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143548
Approved by: https://github.com/yanboliang, https://github.com/jansel, https://github.com/williamwen42
2024-12-21 01:46:14 +00:00
Guilherme Leobas
673cc88fd6 Add support for contextmanager in Dynamo (#136033)
Fixes #130559

* Intro

This PR adds support for `@contextmanager` in Dynamo. We chose to limit the
scope of this work to only `@contextmanager` and plan to handle generators fully
in #141055 (still in draft).

* Motivation

Dynamo lacks support for generator functions. When it encounters one, it traces
it as if it were a regular function. This is problematic because it can lead to
incorrect behavior. To illustrate, consider the test case below:

```python
import torch
import contextlib

@contextlib.contextmanager
def set_default_dtype(dtype):
    old_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(dtype)
        yield
    finally:
        torch.set_default_dtype(old_dtype)

@torch.compile(backend="eager", fullgraph=True)
def fn():
    with set_default_dtype(torch.float64):
        x = torch.tensor([3.0, 3.0 + 5.0j])
    return x
```

Before this work, Dynamo would not stop at the `yield`, and the graph produced
would contain both calls to `set_default_dtype` executed one after the other.
This is incorrect because the context manager should execute code before and
after the `yield`.

* List of changes

`YIELD_VALUE` now raises an exception (`YieldValueOp`) to signal that control
flow must be suspended and returned to the caller. Additionally, `RETURN_VALUE`
behaves differently in a generator function. Unlike regular functions, where
`RETURN_VALUE` indicates the final result, in generators it signifies that the
generator is exhausted and implicitly raises `StopIteration`.

A new `VariableTracker` named `FunctionDecoratedByContextlibContextManagerVariable`
was introduced to handle `@contextmanager`. This variable tracker acts not just
as a wrapper for the original function but also maintains an internal `tx`
(InstructionTranslator) object to suspend and return control flow to the parent
tracer when a `yield` is encountered.

* Corner cases

Returning a context manager from a compiled function is not supported. This
would require PyTorch to synchronize the generator state between Dynamo and the
interpreter. Any attempt to return it will result in an `IncorrectUsage`
exception.

Graph breaks require special handling as well. In the event of a graph break,
the frame associated with the context manager is skipped, and the context
manager runs in eager mode.

* This PR is breaking my code

There is a configuration flag (`enable_trace_contextlib`) that can be set to
`False` to disable tracing context managers. If this still causes crashes,
please revert this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136033
Approved by: https://github.com/zou3519
2024-12-20 12:02:20 +00:00