Commit Graph

37 Commits

IvanKobzarev
4439255148 [aotd] Support saved tensors hooks in aot_autograd (#150032)
https://github.com/pytorch/pytorch/issues/148222

Goal:

At the moment, autograd saved tensors hooks are run in eager mode after the compiled forward.
They are executed at the same time for all saved tensors.
Hooks can be used to reduce the amount of memory used for saved tensors, e.g. by quantizing them or offloading them to CPU, so running them all at once is suboptimal for peak memory.
A better solution is to put the hooks in the graph, as close as possible to the last usage of the tensor.

The goal of this PR is to get user-specified autograd saved tensors hooks into the graph.

Logic:

UX:
If the user specifies torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm), where pack_gm and unpack_gm are torch.fx.GraphModule instances, then AOTAutograd will retrace those graph modules, applying decompositions and functionalization, and inline the resulting graphs into the forward epilogue and backward prologue.
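
For illustration, a minimal sketch of this UX, assuming the hook bodies are simple enough for torch.fx.symbolic_trace. The float8 pack/unpack pair mirrors the pack_float8 graphs in the logs below; names are illustrative, and per the Backwards Compatibility section the hooks additionally need the @_inlineable_saved_tensors_hooks marking to be inlined:

```python
import torch
import torch.fx

def pack_float8(x):
    # Save the original dtype alongside the quantized tensor.
    return x.dtype, x.to(torch.float8_e4m3fn)

def unpack_float8(packed):
    # Indexing (rather than tuple unpacking) keeps this symbolically traceable.
    dtype = packed[0]
    tensor = packed[1]
    return tensor.to(dtype)

pack_gm = torch.fx.symbolic_trace(pack_float8)
unpack_gm = torch.fx.symbolic_trace(unpack_float8)

@torch.compile
def fn(x):
    x = x + 1
    return x.sum()

x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_gm, unpack_gm):
    fn(x).backward()
```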

The user may want control logic in the hooks, for example applying quantization only for specific dtypes and sizes.

This is also possible: the user can put the logic into a torch.fx.wrap-ed function and use symbolic tracing to make a GraphModule.

In that case, AOTAutograd caching works only if the user explicitly sets the "user_cache_hash" metadata on the torch.fx.wrap call_function node.

If this metadata is set, the aot_autograd cache can use the saved cache artifact; if it is not set, the cache is bypassed.
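
A hedged sketch of that opt-in, using torch.fx.wrap so the control logic stays opaque to tracing; the "user_cache_hash" key follows the description above, and the function name and hash value are hypothetical:

```python
import torch
import torch.fx

@torch.fx.wrap  # stays a single call_function node under symbolic_trace
def quantize_if_fp32(x):
    # Control logic: quantize only fp32 tensors.
    return x.to(torch.float8_e4m3fn) if x.dtype == torch.float32 else x

def pack(x):
    return x.dtype, quantize_if_fp32(x)

pack_gm = torch.fx.symbolic_trace(pack)

# Hypothetical opt-in to AOTAutograd caching: without this metadata on the
# wrapped call_function node, the cache is bypassed.
for node in pack_gm.graph.nodes:
    if node.op == "call_function" and node.target is quantize_if_fp32:
        node.meta["user_cache_hash"] = "quantize_if_fp32_v1"
```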

Dynamo:
Dynamo traces the pack and unpack hooks, installs them as subgraphs, and explicitly adds them to the output_graph (as those subgraphs are otherwise unused, they would not be copied into the result by default).

The complexity here is that at this moment we do not have example inputs for the hooks, so we trace pack_hook with some tensor taken from the inputs.
The resulting subgraphs are added to the AOTAutograd cache hash.

In AOTAutograd we retrace the graph with the actual saved tensors coming from the partitioner.

Backwards Compatibility:
As current hooks are executed in eager mode and not all of them are traceable, we only try to put into the graph those hooks explicitly marked by the user with the @_inlineable_saved_tensors_hooks annotation.
For other hooks, or when compiled autograd is enabled, we keep the existing logic.

Recompilations:
Hooks are guarded with a lambda guard matching the function id, to cause recompilation if the user reruns the compiled function with different hooks.

Aot_autograd:
After the partitioner has prepared the forward and backward modules, we retrace the graphs Dynamo prepared for the pack and unpack hooks and inline them into the epilogue of the forward and the prologue of the backward. Forward outputs and backward inputs change accordingly, transparently to the user.

We do not try to place the hooks close to the last usage, etc.; we rely on Inductor to do that optimization.

```
INFO: TRACED GRAPH
 ===== Forward graph pre saved_tensors_hooks inlining 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
        return (view, add, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== Backward graph pre saved_tensors_hooks inlining 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2])
        return (view, add, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== saved_tensors_pack_hook add 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class pack_float8(torch.nn.Module):
    def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn);  x_1 = None
        return (torch.float32, _to_copy)

INFO: TRACED GRAPH
 ===== saved_tensors_unpack_hook add 3 =====
 <eval_with_key>.22 from /data/users/ivankobzarev/a/pytorch/torch/fx/experimental/proxy_tensor.py:1225 in wrapped class pack_float8(torch.nn.Module):
    def forward(self, x_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(x_1, dtype = torch.float8_e4m3fn);  x_1 = None
        return (torch.float32, _to_copy)

INFO: TRACED GRAPH
 ===== Forward graph 3 =====
 /data/users/ivankobzarev/a/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", primals_3: "f32[s0, s1][s1, 1]cuda:0"):
         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6660 in simple_fn, code: x = x + 1
        add: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(primals_3, 1);  primals_3 = None

        # No stacktrace found for following nodes
        _to_copy: "f8e4m3fn[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add, dtype = torch.float8_e4m3fn)

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        view: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.view.default(add, [primals_1, primals_2]);  add = None
        return (view, _to_copy, primals_1, primals_2)

INFO: TRACED GRAPH
 ===== Backward graph 3 =====
 <eval_with_key>.21 class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s0)", primals_2: "Sym(s1)", add_packed_2: "f8e4m3fn[s0, s1][s1, 1]cuda:0", tangents_1: "f32[s0, s1][s1, 1]cuda:0"):
        # No stacktrace found for following nodes
        _to_copy: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten._to_copy.default(add_packed_2, dtype = torch.float32);  add_packed_2 = None

         # File: /data/users/ivankobzarev/a/pytorch/test/functorch/test_aotdispatch.py:6661 in simple_fn, code: x = SAF.apply(x)
        add_7: "f32[s0, s1][s1, 1]cuda:0" = torch.ops.aten.add.Tensor(tangents_1, _to_copy);  tangents_1 = _to_copy = None
        return (None, None, add_7)

```

Differential Revision: [D72187044](https://our.internmc.facebook.com/intern/diff/D72187044)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150032
Approved by: https://github.com/bdhirsh
2025-05-22 14:09:38 +00:00
bobrenjc93
f649ee73ce Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache-missing due to PGO. At a high level, the problem looks as follows:

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do a dynamic compile, get a dynamic graph, and save it to PGO. Crucially, what we save to PGO is a superset of what is actually dynamic: if we notice an input varying, we mark it as dynamic in PGO even if that value later gets specialized. When a value gets specialized, we remove its symbol from the graph. This results in an interesting conundrum where, although we produce the same isomorphic graph, PGO makes the second run cache miss. Let's see how...

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get an fx graph, look it up in the cache and... whoops! cache miss! This is the aforementioned behavior: the PGO profile causes us to over-allocate symbols. In practice this means we save a graph in the cache with symbols x:s1, y:s3, and on the second attempt we cache miss with x:s1, y:s6, where symbols s3, s4, s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names, which gives a reasonably stable assignment. To prevent catastrophic symbol collisions, we resolve collisions with linear probing.
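
A minimal sketch of the idea, with a hypothetical symbol-id space; this is not the actual implementation:

```python
import hashlib

SYMBOL_SPACE = 2**16  # hypothetical size of the symbol-id space

def stable_symbol_id(source_name: str, taken: set) -> int:
    """Derive a symbol id from the source name; linearly probe on collision."""
    h = int.from_bytes(hashlib.sha256(source_name.encode()).digest()[:8], "big")
    idx = h % SYMBOL_SPACE
    while idx in taken:  # linear probing guarantees uniqueness
        idx = (idx + 1) % SYMBOL_SPACE
    taken.add(idx)
    return idx

taken = set()
s = stable_symbol_id("L['x'].size()[0]", taken)  # same source name -> same id across runs
```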

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-28 05:36:32 +00:00
PyTorch MergeBot
af7719a2fa Revert "Use source hashing to generate consistent symbolic ids (#149665)"
This reverts commit 1f92348dc6.

Reverted https://github.com/pytorch/pytorch/pull/149665 on behalf of https://github.com/malfet due to Broke trunk, see 6eb3c2e282/1 ([comment](https://github.com/pytorch/pytorch/pull/149665#issuecomment-2758578187))
2025-03-27 16:02:27 +00:00
bobrenjc93
1f92348dc6 Use source hashing to generate consistent symbolic ids (#149665)
This PR was inspired by internal models that were cache-missing due to PGO. At a high level, the problem looks as follows:

Run 1, Invocation 1: We do static compile, save some example values in PGO/automatic dynamic

Run 1, Invocation 2: We detect varying inputs, do a dynamic compile, get a dynamic graph, and save it to PGO. Crucially, what we save to PGO is a superset of what is actually dynamic: if we notice an input varying, we mark it as dynamic in PGO even if that value later gets specialized. When a value gets specialized, we remove its symbol from the graph. This results in an interesting conundrum where, although we produce the same isomorphic graph, PGO makes the second run cache miss. Let's see how...

Run 2, Invocation 1: We fetch the PGO, over-mark things as dynamic, get an fx graph, look it up in the cache and... whoops! cache miss! This is the aforementioned behavior: the PGO profile causes us to over-allocate symbols. In practice this means we save a graph in the cache with symbols x:s1, y:s3, and on the second attempt we cache miss with x:s1, y:s6, where symbols s3, s4, s5 were all optimistically marked dynamic by PGO and subsequently specialized.

We solve this problem by hashing the source names, which gives a reasonably stable assignment. To prevent catastrophic symbol collisions, we resolve collisions with linear probing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149665
Approved by: https://github.com/Mingming-Ding, https://github.com/laithsakka
2025-03-27 03:39:27 +00:00
Animesh Jain
d6513f3246 [dynamo] Support list subclasses and fix dict subclasses mutation bugs (#146819)
This PR adds support for list subclasses. Among other things, it includes:

1) Tracking the mutations on internal vts like `_dict_vt` and `_list_vt` using sources. This helps identify if there was a mutation in the underlying data structures, and we need to reconstruct it.
2) `UserDefinedObjectVariable` now has a new method - `is_modified` which `side_effect` infra relies upon to check mutations in the underlying vts (like `_dict_vt`).
3) `reconstruction` logic ensures that we use the `dict.__getitem__` and `list.__getitem__` methods. This is super important because we don't want to call the overridden `__getitem__` methods (see the sketch below).
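
To illustrate why, a plain-Python sketch independent of Dynamo:

```python
class NoisyList(list):
    def __getitem__(self, i):
        return "intercepted"  # user override we must not call

xs = NoisyList([1, 2, 3])
print(xs[1])                    # "intercepted" (overridden method)
print(list.__getitem__(xs, 1))  # 2 (raw storage, what reconstruction needs)
```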

If this PR is hard to review, please let me know. I can break it into several small PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146819
Approved by: https://github.com/StrongerXi, https://github.com/jansel
2025-02-12 17:46:02 +00:00
Tom Ritchford
d25e6e623f Fix unused Python variables in test/[a-d]* (#134665)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134665
Approved by: https://github.com/albanD
2024-12-13 22:13:12 +00:00
Yuanhao Ji
a1327fac45 [Dynamo] Replace torch._dynamo.optimize() with torch.compile() [5/N] (#140663)
related commits:

- #139706
- #140238
- #140247
- #140253
- #140663
- #140688

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140663
Approved by: https://github.com/williamwen42
2024-11-18 04:11:56 +00:00
Yuanhao Ji
d6b3ad4de2 [Dynamo] Replace torch._dynamo.optimize() with torch.compile() [2/N] (#140238)
related commits:

- #139706
- #140238
- #140247
- #140253

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140238
Approved by: https://github.com/soulitzer
2024-11-13 05:13:39 +00:00
Edward Z. Yang
c7ca89a11a Improve print stack/locals printing in comptime (#133651)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133651
Approved by: https://github.com/anijain2305
2024-08-27 01:29:50 +00:00
Oguz Ulgen
6f51782a59 Add comptime.sleep (#133362)
Adds a comptime sleep for NCCL timeout testing. The unit test is not great.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133362
Approved by: https://github.com/jamesjwu
2024-08-14 19:32:18 +00:00
YangQun1
589aef4bb0 Fix py codegen to delete values that don't have any users (#131028)
Fixes #131025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028
Approved by: https://github.com/ezyang
2024-08-01 03:18:37 +00:00
Xuehai Pan
918ece4f4d [BE][Easy][11/19] enforce style for empty lines in import segments in test/dy*/ (#129762)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129762
Approved by: https://github.com/anijain2305
2024-07-27 17:43:53 +00:00
PyTorch MergeBot
c3679bed35 Revert "Fix py codegen to delete values that don't have any users (#131028)"
This reverts commit 91aba7baac.

Reverted https://github.com/pytorch/pytorch/pull/131028 on behalf of https://github.com/clee2000 due to broke inductor/test_triton_kernels inductor/test_triton_kernels.py::KernelTests::test_triton_kernel_functionalize [GH job link](https://github.com/pytorch/pytorch/actions/runs/10094659640/job/27915271250) [HUD commit link](91aba7baac) ([comment](https://github.com/pytorch/pytorch/pull/131028#issuecomment-2251058374))
2024-07-25 17:42:18 +00:00
YangQun1
91aba7baac Fix py codegen to delete values that don't have any users (#131028)
Fixes #131025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028
Approved by: https://github.com/ezyang
2024-07-25 13:04:23 +00:00
PyTorch MergeBot
8ffd109a00 Revert "Fix py codegen to delete values that don't have any users (#131028)"
This reverts commit 466c167b71.

Reverted https://github.com/pytorch/pytorch/pull/131028 on behalf of https://github.com/atalman due to breaks CI ([comment](https://github.com/pytorch/pytorch/pull/131028#issuecomment-2247771530))
2024-07-24 12:21:43 +00:00
YangQun1
466c167b71 Fix py codegen to delete values that don't have any users (#131028)
Fixes #131025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028
Approved by: https://github.com/ezyang
2024-07-24 01:03:56 +00:00
Edward Z. Yang
022adf8c5e Fix bug for comptime.get_local for cells/closures (#126637)
I wasn't paying enough attention and didn't notice that LOAD_DEREF is
defined differently for InliningInstructionTranslator.  Match it up with
the code there.

This also fixes comptime.print(), which was broken, because closing over
an argument turned it into a cell rather than a regular local.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126637
Approved by: https://github.com/yanboliang
2024-05-20 17:51:28 +00:00
Edward Z. Yang
9c9d0c2fab Add VariableTracker.debug_repr (#126299)
Now you can print arbitrary values at compile time with
comptime.print()
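
A hedged example of the capability this enables (backend and shapes are arbitrary):

```python
import torch
from torch._dynamo.comptime import comptime

@torch.compile(backend="eager")
def f(x):
    y = x.sin()
    comptime.print(y)  # prints the VariableTracker's debug repr at compile time
    return y

f(torch.randn(2))
```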

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126299
Approved by: https://github.com/jansel
ghstack dependencies: #126292
2024-05-15 23:55:29 +00:00
Animesh Jain
c568b84794 [dynamo][guards] Move backend match to eval_frame (#121954)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121954
Approved by: https://github.com/jansel
2024-03-17 06:52:10 +00:00
David Berard
5c0976fa04 Revert "[dynamo] guarded config (#111299)" (#115386)
This reverts commit 5927e9cbf2.

Differential Revision: [D51959266](https://our.internmc.facebook.com/intern/diff/D51959266)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115386
Approved by: https://github.com/yanboliang, https://github.com/malfet
ghstack dependencies: #115384, #115401, #115385
2023-12-11 19:35:42 +00:00
Jon Chuang
5927e9cbf2 [dynamo] guarded config (#111299)
---

Fixes: https://github.com/pytorch/pytorch/issues/110682

Replaces: https://github.com/pytorch/pytorch/pull/111074

The guards are installed based on the config that is valid at the call to `torch.compile`, rather than at any subsequent call / triggered compilation. Subsequent compilations will restore the saved config if the current global config does not match it.
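
A toy sketch of that behavior, not Dynamo's implementation: snapshot the config at the compile call and restore the snapshot when the live config has drifted.

```python
_global_config = {"dynamic_shapes": True}  # stand-in for dynamo's global config

def compile_with_config_guard(fn):
    snapshot = dict(_global_config)          # config valid at the compile call
    def wrapper(*args, **kwargs):
        if _global_config != snapshot:       # mismatch with the saved config
            _global_config.clear()
            _global_config.update(snapshot)  # restore the saved config
        return fn(*args, **kwargs)
    return wrapper
```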

TODO:
- [X] add tests

Follow up PRs:
- [x] add revised cache size computation (follow up PR: #111300 , based on: https://github.com/pytorch/pytorch/pull/107496)
- [ ] handle run-only mode?
- [ ] config restoration itself is not thread-safe (tracked: https://github.com/pytorch/pytorch/issues/111150)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111299
Approved by: https://github.com/ezyang
2023-11-17 09:59:58 +00:00
ydwu4
94a54b89aa [dynamo] Add BACKEND_MATCH guard to detect and recompile when backend changes (#107337)
**Motivation:**
We try to make torch.cond use torch.compile automatically so that we can error out when there are side effects in the branches and correctly handle closures.

Before this PR, the following code only produced a warning unless the raise_on_backend_change config was turned on (which upgraded it to an error):
```python
def foo(x): ...

# Inside torch.cond, we'd like to do something like
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else.
# Dynamo will use the cached code of foo for "eager" backend
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

**Implementation:**
1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.
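
A toy illustration of point 1 (names are illustrative, not the real guard machinery): cached code records the backend it was compiled for, and a BACKEND_MATCH-style check rejects it when the most recent backend differs.

```python
most_recent_backend = None
_code_cache = {}  # fn -> (backend, "compiled" code)

def call_compiled(fn, backend, *args):
    global most_recent_backend
    most_recent_backend = backend
    entry = _code_cache.get(fn)
    if entry is None or entry[0] != most_recent_backend:  # guard check fails
        entry = (backend, fn)  # recompile; plain fn stands in for compiled code
        _code_cache[fn] = entry
    return entry[1](*args)
```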

Note: more lines are printed in the debug log due to the newly added context manager and guards.

**Test Plan:**
Removed the original tests that raise on a different backend, and added a new test checking that the BACKEND_MATCH guard detects a backend change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107337
Approved by: https://github.com/jansel
2023-09-14 15:49:30 +00:00
Michael Voznesensky
de0b18fad9 Use user directed names for variables where possible (#109092)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109092
Approved by: https://github.com/ezyang
ghstack dependencies: #108846
2023-09-13 07:44:04 +00:00
PyTorch MergeBot
38fcf77a1b Revert "[dynamo] Add BACKEND_MATCH guard to detect and recompile when backend changes (#107337)"
This reverts commit 1a64ec7dd4.

Reverted https://github.com/pytorch/pytorch/pull/107337 on behalf of https://github.com/huydhn due to Sorry for reverting your change but inductor perf smoke test starts to regress after this ([comment](https://github.com/pytorch/pytorch/pull/107337#issuecomment-1710974588))
2023-09-08 02:03:48 +00:00
ydwu4
1a64ec7dd4 [dynamo] Add BACKEND_MATCH guard to detect and recompile when backend changes (#107337)
**Motivation:**
We try to make torch.cond use torch.compile automatically so that we can error out when there are side effects in the branches and correctly handle closures.

Before this PR, the following code only produced a warning unless the raise_on_backend_change config was turned on (which upgraded it to an error):
```python
def foo(x): ...

# Inside torch.cond, we'd like to do something like
torch.compile(foo, backend="eager", fullgraph=True)(...)
...
# Users may then call torch.compile somewhere else.
# Dynamo will use the cached code of foo for "eager" backend
# but we expect dynamo to recompile with "inductor" backend.
torch.compile(foo, backend="inductor")(...)
```

This PR adds a BACKEND_MATCH guard. Effectively, it implements a per-backend cache. In the above example, the cached code for "eager" won't work for "inductor" due to guard check failures and the second torch.compile will do a re-compilation. In the future, it might be useful to have something like a configuration guard that guards against dynamo configuration changes across different compiles (e.g. compile a function with fullgraph=False then compile it again with fullgraph=True).

**Implementation:**
1. We add a guarded_backend_cache and check the most_recent_backend against the backend associated with cached code. We also remove the raise_on_backend_change flag.

2. The newly added context manager and guard add more lines to the debug log, so we raise the upper limit from 50 to 55.

**Test Plan:**
Removed the original tests that raise on a different backend, and added a new test checking that the BACKEND_MATCH guard detects a backend change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107337
Approved by: https://github.com/jansel
2023-09-07 22:45:54 +00:00
Michael Lazos
05eea20eb9 [dynamo] Simulate torch function enablement state (#105091)
Part of https://github.com/pytorch/pytorch/issues/93723

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105091
Approved by: https://github.com/voznesenskym, https://github.com/anijain2305
2023-07-13 17:42:20 +00:00
Edward Z. Yang
3804eb109a Always register SHAPE_ENV guard (#103521)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103521
Approved by: https://github.com/Skylion007
2023-06-13 20:15:20 +00:00
Mark Saroufim
95fced4483 Pretty dataclass dynamo explain (#102869)
Also thinking out loud: maybe we only print graph break reasons? And for the rest we have a verbose print which prints everything?

TODO: some tests are failing based on what they expect a guard string to look like; easy to fix, I'll do it early next week.

# After

```
(sourcetorch) ubuntu@ip-172-31-1-136:~/test$ python pretty.py
BREAK
Graph Count: 2
Graph Break Count: 1
Op Count: 2
Break Reasons:
  Break Reason 1:
    Reason: call_function BuiltinVariable(print) [ConstantVariable(str)] {}
    User Stack:
      <FrameSummary file /home/ubuntu/test/pretty.py, line 6 in fn>
Ops per Graph:
  Ops 1:
    <built-in function add>
  Ops 2:
    <built-in function add>
Out Guards:
  Guard 1:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: ['GRAD_MODE']
    Code List: ['___is_grad_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 2:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 3:
    Name: "G['print']"
    Source: global
    Create Function: BUILTIN_MATCH
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 4:
    Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: ['DETERMINISTIC_ALGORITHMS']
    Code List: ['not ___are_deterministic_algorithms_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 5:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 6:
    Name: ''
    Source: global
    Create Function: GRAD_MODE
    Guard Types: ['GRAD_MODE']
    Code List: ['___is_grad_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 7:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 8:
    Name: ''
    Source: global
    Create Function: DETERMINISTIC_ALGORITHMS
    Guard Types: ['DETERMINISTIC_ALGORITHMS']
    Code List: ['not ___are_deterministic_algorithms_enabled()']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 9:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
Compile Times: TorchDynamo compilation metrics:
Function                        Runtimes (s)
------------------------------  --------------
_compile                        0.0164, 0.0035
OutputGraph.call_user_compiler  0.0000, 0.0000
```

## Before

```
('Dynamo produced 2 graphs with 1 graph break and 2 ops', [{Guard(name='print', source=<GuardSource.GLOBAL: 1>, create_fn=<function GuardBuilder.BUILTIN_MATCH at 0x7f92ea5009d0>, is_volatile=False, guard_types=None, code_list=None, obj_weakref=None, guarded_class_weakref=None), Guard(name='x', source=<GuardSource.LOCAL: 0>, create_fn=<function GuardBuilder.TENSOR_MATCH at 0x7f92ea501000>, is_volatile=False, guard_types=['TENSOR_MATCH'], code_list=None, obj_weakref=<weakref at 0x7f9224d28f40; dead>, guarded_class_weakref=<weakref at 0x7f92d81734c0; to 'torch._C._TensorMeta' at 0x540b610 (Tensor)>)}, {Guard(name='x', source=<GuardSource.LOCAL: 0>, create_fn=<function GuardBuilder.TENSOR_MATCH at 0x7f92ea501000>, is_volatile=False, guard_types=['TENSOR_MATCH'], code_list=None, obj_weakref=<weakref at 0x7f9224d5e700; dead>, guarded_class_weakref=<weakref at 0x7f92d81734c0; to 'torch._C._TensorMeta' at 0x540b610 (Tensor)>)}], [GraphModule(), GraphModule()], [[<built-in function add>], [<built-in function add>]], [GraphCompileReason(reason='call_function BuiltinVariable(print) [ConstantVariable(str)] {}', user_stack=[<FrameSummary file <ipython-input-1-9e2ddb639697>, line 6 in fn>]), GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file <ipython-input-1-9e2ddb639697>, line 8 in <graph break in fn>>])], 'Dynamo produced 2 graphs with 1 graph break and 2 ops\n Break reasons: \n\n1. call_function BuiltinVariable(print) [ConstantVariable(str)] {}\n  File "<ipython-input-1-9e2ddb639697>", line 6, in fn\n    print("BREAK")\n \n2. return_value\n  File "<ipython-input-1-9e2ddb639697>", line 8, in <graph break in fn>\n    return x\n \nTorchDynamo compilation metrics:\nFunction                        Runtimes (s)\n------------------------------  --------------\n_compile                        0.0418, 0.0084\nOutputGraph.call_user_compiler  0.0001, 0.0001')

```

## Program

```python
import torch
import torch._dynamo

def fn(x):
    x = x + 1
    print("BREAK")
    x = x + 1
    return x

out = torch._dynamo.explain(fn, torch.randn(10))
print(out)

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102869
Approved by: https://github.com/voznesenskym
2023-06-07 22:38:57 +00:00
Brian Hirsh
98ab11a2c3 separate out dynamo .requires_grad and .is_grad_enabled guards (#100570)
Fixes https://github.com/pytorch/pytorch/issues/100977

This will hopefully fix this error (from [issue](https://github.com/pytorch/pytorch/issues/99616))

This PR fixes an internal model: we were running an inductor inference graph, but `torch.is_grad_enabled()` was True, causing us to error inside of the inference graph when we encountered an out= operator.
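
The failure mode can be reproduced in plain eager mode; the inference graph hit the same error because grad mode was unexpectedly on:

```python
import torch

a = torch.randn(4, requires_grad=True)
out = torch.empty(4)
try:
    torch.sin(a, out=out)  # out= ops don't support autograd
except RuntimeError as e:
    print(e)  # functions with out= arguments don't support automatic differentiation
```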

I haven't been able to create a smaller repro - before landing this, I want to create a smaller repro to convince myself of why we need to separate out these guards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100570
Approved by: https://github.com/ezyang
2023-05-24 14:58:40 +00:00
PyTorch MergeBot
7f3fed125e Revert "separate out dynamo .requires_grad and .is_grad_enabled guards (#100570)"
This reverts commit 1fabee399d.

Reverted https://github.com/pytorch/pytorch/pull/100570 on behalf of https://github.com/PaliC due to breaking inductor tests along with #101219 ([comment](https://github.com/pytorch/pytorch/pull/100570#issuecomment-1555271267))
2023-05-19 21:29:09 +00:00
Brian Hirsh
1fabee399d separate out dynamo .requires_grad and .is_grad_enabled guards (#100570)
Fixes https://github.com/pytorch/pytorch/issues/100977

This will hopefully fix this error (from [issue](https://github.com/pytorch/pytorch/issues/99616))

This PR fixes an internal model: we were running an inductor inference graph, but `torch.is_grad_enabled()` was True, causing us to error inside of the inference graph when we encountered an out= operator.

I haven't been able to create a smaller repro - before landing this, I want to create a smaller repro to convince myself of why we need to separate out these guards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100570
Approved by: https://github.com/ezyang
2023-05-19 16:14:56 +00:00
Edward Z. Yang
e47e8c9d98 Guard on default device (#99551)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99551
Approved by: https://github.com/voznesenskym, https://github.com/mlazos
2023-04-20 17:02:59 +00:00
Michael Voznesensky
b1e60bfb6a Pass f_locals as a dict rather than kwargs (#98107)
Fixes https://github.com/pytorch/pytorch/issues/97688

One big problem is that instead of printing x < y we now print
`E["x"] < E["y"]` and now all of the tests wobbled and I'm mad.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98107
Approved by: https://github.com/ezyang
2023-04-04 00:30:08 +00:00
William Wen
14ef91cea6 [dynamo 3.11] small bug fixes (#96508)
Bugs fixed:
- CALL_FUNCTION_EX expects null pop in symbolic_convert
- make_function_with_closure codegen requires a push_null
- copy over the closure in eval_frame.c
- add JUMP_FORWARD to terminal opcodes
- enum repr fix in utils.py
- fix symbolic_convert's break_graph_if_unsupported wrapper

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96508
Approved by: https://github.com/jansel
2023-03-31 18:18:12 +00:00
Sam Gross
87f5e92916 [dynamo] Add guards for deterministic algos (#96695)
Inductor now falls back to eager mode for deterministic algos. This adds guards in Dynamo to check whether the deterministic algos mode has changed.

See #93537
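
A hedged sketch of the user-visible effect: flipping the mode between calls should trigger recompilation instead of reusing stale code.

```python
import torch

@torch.compile(backend="eager")
def f(x):
    return x.cumsum(0)

torch.use_deterministic_algorithms(False)
f(torch.randn(4))
torch.use_deterministic_algorithms(True)
f(torch.randn(4))  # DETERMINISTIC_ALGORITHMS guard fails -> recompile
```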

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96695
Approved by: https://github.com/ngimel, https://github.com/jansel
2023-03-31 16:28:45 +00:00
Edward Z. Yang
7dd95ad7f3 Add a convenience shortcut for accessing size on ComptimeVar (#95404)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95404
Approved by: https://github.com/voznesenskym
2023-02-27 02:02:02 +00:00
Edward Z. Yang
dfe916ca88 Dynamo comptime, with public ComptimeContext API (#90983)
This PR adds `@comptime`, a decorator that causes a given function to be executed at compile time when Dynamo is symbolically evaluating their program. To query the Dynamo state, we offer a public ComptimeContext API which provides a limited set of APIs for querying Dynamo's internal state. We intend for users to use this API and plan to keep it stable. Here are some things you can do with it:

* You want to breakpoint Dynamo compilation when it starts processing a particular line of user code: give comptime a function that calls breakpoint
* You want to manually induce a graph break for testing purposes; give comptime a function that calls unimplemented
* You want to perform a debug print, but you don't want to induce a graph break; give comptime a function that prints.
* You can print what the symbolic locals at a given point in time are.
* You can print out the partial graph the Dynamo had traced at this point.
* (My original motivating use case.) You want to add some facts to the shape env, so that a guard evaluation on an unbacked SymInt doesn't error with data-dependent. Even if you don't know what the final user API for this should be, with comptime you can hack out something quick and dirty. (This is not in this PR, as it depends on some other in flight PRs.)

Check out the tests to see examples of comptime in action.

In short, comptime is a very powerful debugging tool that lets you drop into Dynamo from user code, without having to manually jerry-rig pdb inside Dynamo to trigger after N calls.
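
A hedged example of the pattern, per the capabilities listed above (exact ComptimeContext method signatures may differ across versions):

```python
import torch
from torch._dynamo.comptime import comptime

@torch.compile(backend="eager")
def f(x):
    x = x + 1
    @comptime
    def _(ctx):
        ctx.print_locals()  # symbolic locals at this point in the trace
        ctx.print_graph()   # partial FX graph Dynamo has traced so far
    return x * 2

f(torch.randn(3))
```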

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90983
Approved by: https://github.com/jansel
2022-12-19 11:06:01 +00:00