Commit Graph

74 Commits

Author SHA1 Message Date
Ryan Guo
ea1d11cf74 [dynamo] Represent all cells as NewCellVariable (#140153)
In addition to `NewCellVariable`, Dynamo has 3 ways of modeling cell objects:
1. For cells captured and created by the root frame, represent them as
   their contents in `root_tx.symbolic_locals`, which `LOAD_DEREF` and
   `STORE_DEREF` update directly, without going through `SideEffects`.
2. `ClosureVariable`: this is created when cells from (1) are captured
   by a newly created function Dynamo is about to inline. It's a handle
   with a name that redirects `LOAD_DEREF` and `STORE_DEREF` back (1),
   to make `root_tx.symbolic_locals` up-to-date.
3. For cells that are captured by both the root frame and some
   pre-existing function Dynamo is about to inline, represent those
   cells as contents, and do not allow writes to them.

Note that (2) and (3) are mainly to conform with (1) -- to make sure
Dynamo has a consistent modeling of cells for the same cell objects.

In this patch, we represent all of these cells as `NewCellVariable`. The
main new code paths introduced are:
- using `NewCellVariable` to model cell objects created by the root
  frame (the cells are passed in as input to `InstructionTranslator`),
  this is what allows us to get rid of all 3 legacy paths above.
- adding a new `AutoDerefLocalSource` to deal with the python-code
  level (guards) and bytecode level (codegen) auto-dereferencing
  behavior, when accessing pre-existing python cells. This also
  involves a tiny update to guard manager generation.
- plumbing some extra info into `LocalSource` and `CellVariable` so that
  we can still emit `LOAD_DEREF`, `STORE_DEREF`, `LOAD_CLOSURE` (instead
  of `make_cell`, `cell_contents` attribute access, and `LOAD_FAST`),
  which is important for readability, performance, and some
  assumptions `bytecode_transformation.py` makes.

As a result, this patch removes a lot of the now-dead code paths and
TODOs. Notably, it significantly simplified the `prune_dead_locals`
function, which was duplicating a lot of the logic from
`prune_dead_object_new`; this conveniently closes #137123.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140153
Approved by: https://github.com/jansel
ghstack dependencies: #140330, #140152, #140436, #140435
2024-11-15 17:17:30 +00:00
Ryan Guo
f459c3095f [dynamo] Document codegen and clean up some code paths (#139670)
This patch
1. Adds documentation to `PyCodegen.__call__`, `PyCodegen.tempvars` and
   the `allow_cache` flag.
2. Merges a few existing code paths in `PyCodegen.__call__`.
3. removes the `elif var in cg.tempvars` code path in
   `codegen_save_tempvars`, because it's no longer needed after #113725,
   as we have up-to-date `VariableTracker.source` now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139670
Approved by: https://github.com/jansel
ghstack dependencies: #139538
2024-11-07 03:14:16 +00:00
Ryan Guo
183b386cb2 [dynamo] Simplify Codegen for variables with MutableSideEffects (#139538)
This effectively undoes #115095, which is not longer be needed after #113725.

Why did we need #115095? I went back in history and found that [this line](https://github.com/pytorch/pytorch/pull/113725/files#diff-0bb1756725c4426408938314b0c9d3988ae5bf49994892d7038ad7746e209e9fR86)
actually fixed what #115095 fixed. Specifically, without the
`allow_cache` check for the "dup_top" optimization, we could incorrectly
codegen based on source, despite `codegen_update_mutated` requested to
codegen from value, for updates to pre-existing lists, etc. Since #113725 added
the `allow_cache` check, we no longer need the `mutable_side_effects_from_source`
code path from #115095.

However, #115442 introduced a `value_from_source` flag which didn't
account for the `mutable_side_effects_from_source` branch. So this patch
adds an extra check to keep existing behavior for export, and leaves a
TODO for investigating what exactly export wants from codegen, when it
comes to side effects and sources.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139538
Approved by: https://github.com/jansel
2024-11-07 03:14:16 +00:00
Ryan Guo
693a0a1bd4 [dynamo][NFC] Rename mutable_local and add documentation (#139339)
This patch addresses the renaming part of #133027, specifically, it
renames the following and adds documentation for relevant classes.
1. `VariableTracker.mutable_local` to `mutation_type`
2. `MatableLocal `to `ValueMutationNew`
3. `MutableSideEffects `to `ValueMutationExisting`
4. `MutableLocalSource` to `SourceType`
5. `MutableLocalSource.Local` to `New`

Note that (2), (3) and (5) are mainly to bring consistency between them
and `AttributeMutationNew`, `AttributeMutationExisting`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139339
Approved by: https://github.com/jansel, https://github.com/mlazos, https://github.com/anijain2305
2024-11-05 19:11:41 +00:00
Bob Ren
500b2bc781 Have as_tensor always return a float64 tensor in dynamo (#138598)
As discussed with @ezyang, this set of diffs are extracting fixes to problems discovered to flipping `specialize_float=False` in https://github.com/pytorch/pytorch/pull/137782. Since these codepaths are exercised in existing tests, I'm going to bias towards shipping speed and put these up with the primary test plan as the global CI. These code paths are all tested via existing tests when `specialize_float=False` and it feels a bit wonky to add more gated tests that only test behavior when this flag is True, especially since these code paths are already covered. That being said, I'm happy to add individual tests if reviewers insist or have a different POV.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138598
Approved by: https://github.com/ezyang
ghstack dependencies: #138595
2024-10-24 20:50:28 +00:00
Ryan Guo
162eba2dee [dynamo] Remove mutable_local.source and index on VariableTracker rather than MutableLocalBase (#137905)
This patch addresses parts of the side-effect refactor proposed in #133027;
specifically, it does 3 things:

1. Change `SideEffects.store_attr_mutations` and `PyCodegen.tempvars`
   to index on `VariableTracker` rather than `MutableLocalBase`.
2. Remove the `source` field from `MutableSideEffects` and
   `AttributeMutation`, and use `VariableTracker.source` instead.
3. Plumb a `overridden_sources: Dict[Source, Source]` from
   `handle_aliases_for_stolen_lists` to `PyCodegen` so that we don't
   update `VariableTracker.source` in place, while still preserving what
   `handle_aliases_for_stolen_lists` needed (i.e., modifying codegen for
   certain `VariableTracker`).

(1) and (2) are merged in 1 patch because of some dependency between
a. `OutputGraph.handle_aliases_for_stolen_lists` which iterates over
   `sideSideEffects.store_attr_mutations.keys()`, and potentially update
   its source field to be completely different.
b. `SideEffects.codegen_update_mutated`, which happens after the above
   and uses `cg(var.mutable_local.source)`.
where if we apply (1) only, (b) breaks, and if we apply (2) only, (a)
breaks.

(3) is needed for correctness, see comments in the PR for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137905
Approved by: https://github.com/jansel, https://github.com/anijain2305, https://github.com/mlazos
2024-10-18 20:20:42 +00:00
Edward Z. Yang
beb46de342 Correctly convert Python float to float64 when passing argument as Tensor (#136413)
I can't actually test the Dynamo codegen fix as it is impossible to
directly use the Tensor at the moment.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136413
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #136599
2024-09-26 16:50:13 +00:00
PyTorch MergeBot
0133fbcfe7 Revert "Correctly convert Python float to float64 when passing argument as Tensor (#136413)"
This reverts commit f0f79dd8f1.

Reverted https://github.com/pytorch/pytorch/pull/136413 on behalf of https://github.com/ezyang due to forward fix is stuck, revert this ([comment](https://github.com/pytorch/pytorch/pull/136413#issuecomment-2372404873))
2024-09-24 21:20:37 +00:00
Edward Z. Yang
f0f79dd8f1 Correctly convert Python float to float64 when passing argument as Tensor (#136413)
I can't actually test the Dynamo codegen fix as it is impossible to
directly use the Tensor at the moment.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136413
Approved by: https://github.com/bobrenjc93
2024-09-23 16:48:08 +00:00
Oguz Ulgen
6e79932543 Add basic mypy annotations to dynamo (#132415)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132415
Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
2024-08-04 18:43:36 +00:00
PyTorch MergeBot
3558a8cf4a Revert "Add basic mypy annotations to dynamo (#132415)"
This reverts commit 71e22e0959.

Reverted https://github.com/pytorch/pytorch/pull/132415 on behalf of https://github.com/ZainRizvi due to Sorry, this PR has entered a weird state in the diff train. Trying to revert it to skip it, and then we can try relanding it ([comment](https://github.com/pytorch/pytorch/pull/132415#issuecomment-2267631785))
2024-08-04 18:39:29 +00:00
William Wen
625af2d27c [dynamo] fix add_push_null callsites with CALL_FUNCTION_EX (#132329)
Also fix a bug in `PyCodegen.add_push_null` where in Python <= 3.12, we may accidentally duplicate a NULL instead of the object on the stack before it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132329
Approved by: https://github.com/anijain2305
2024-08-02 00:29:21 +00:00
Oguz Ulgen
71e22e0959 Add basic mypy annotations to dynamo (#132415)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132415
Approved by: https://github.com/XuehaiPan, https://github.com/jamesjwu
2024-08-01 20:14:25 +00:00
Xuehai Pan
e74ba1b34a [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129767
Approved by: https://github.com/anijain2305
2024-07-31 21:18:11 +00:00
William Wen
375a4d7e9e [3.13, dynamo] decompose fused load/store instructions (#130569)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130569
Approved by: https://github.com/jansel
ghstack dependencies: #130566, #130567, #130568
2024-07-22 18:07:40 +00:00
William Wen
4319147ca9 [3.13, dynamo] fix closures, MAKE_FUNCTION, LOAD_CLOSURE; support SET_FUNCTION_ATTRIBUTE (#130566)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130566
Approved by: https://github.com/jansel
2024-07-22 18:07:28 +00:00
William Wen
539acf7656 [3.13, dynamo] support CALL_KW (#130564)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130564
Approved by: https://github.com/jansel
ghstack dependencies: #130459, #130460, #130461
2024-07-17 09:47:58 +00:00
William Wen
92ac9ee83c [3.13, dynamo] swap null and pop_null in codegen (#130383)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130383
Approved by: https://github.com/jansel
2024-07-13 23:31:57 +00:00
William Wen
79aabaf626 [3.13, dynamo] codegen PUSH_NULL when callable is codegen'd (#129172)
Significant bytecode generation API change!

The new suggested convention to generating bytecode to call a function is now to wrap instructions that push a callable to the stack with `add_push_null`, then that callable is called with `create_call_function` with `push_null=False` (see diff for examples).

In Python 3.13, NULL is now expected to be pushed after the callable. In <=3.12, the NULL was pushed before the callable.  This change abstracts away the exact placement of the NULL, but the developer must be aware that a NULL may be needed when codegen'ing a callable.

This abstraction also reduces the need for the `push_null=True` option in `create_call_function`, which removes the need to rotate a NULL to the right place on the stack with a sequence of `SWAP` instructions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129172
Approved by: https://github.com/jansel
2024-06-22 17:25:23 +00:00
Aaron Orenstein
dcfa7702c3 Flip default value for mypy disallow_untyped_defs [1/11] (#127838)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127838
Approved by: https://github.com/oulgen
2024-06-08 18:16:33 +00:00
Edward Z. Yang
2ba102f689 Implement native support for float inputs in Dynamo and ShapeEnv (#125325)
The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is controlled by `specialize_float` config variable when set to False.

The generated graph looks like this for the test `test_unspec_float_output`:

```
 def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
     l_x_ = L_x_
     l_y_ = L_y_

     # File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
     add: "f32[3]" = l_x_ + 1;  l_x_ = None
     item: "Sym(zf0)" = l_y_.item();  l_y_ = None
     mul: "Sym(2*zf0)" = item * 2;  item = None
     scalar_tensor: "f32[]" = torch.scalar_tensor(mul);  mul = None
     return (add, scalar_tensor)
```

The ingredients:

* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up into the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we will do the rest of the tracing with.  Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We now can generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols, and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we on the fly convert it (via `as_tensor`) to a TensorVariable, which is the true output. We then special case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
2024-05-14 04:10:01 +00:00
Edward Z. Yang
650a248d3e Rename is_unspecialized to pass_arg_as_tensor, add comment (#125496)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125496
Approved by: https://github.com/lezcano
ghstack dependencies: #125395, #125419, #125483, #125494
2024-05-05 16:57:50 +00:00
Jason Ansel
212e460dce [dynamo] Support custom __setattr__ on UserDefinedObjectVariable (#123318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123318
Approved by: https://github.com/anijain2305
2024-04-07 21:06:52 +00:00
William Wen
01547960bc [dynamo, 3.12] remove LOAD_METHOD, update LOAD_ATTR (#122356)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122356
Approved by: https://github.com/jansel
ghstack dependencies: #122146, #122335, #122354, #122355
2024-03-27 20:39:39 +00:00
William Wen
3a67c86f72 [dynamo, 3.12] remove references to PRECALL instruction in 3.12 (#122354)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122354
Approved by: https://github.com/jansel
ghstack dependencies: #122146, #122335
2024-03-27 20:39:39 +00:00
Yifu Wang
09cb42ce29 [dynamo] delete graph_out_{n} after restoring local vars (#122658)
At graph breaks, we create a graph_out_{n} symbol to hold the graph output and
use it to restore the local vars. In addition to their own symbols, the local
vars are kept alive by the symbol we created. This means that if the graph
break is the last usage of one of the symbols, the symbol would still be kept
alive upon graph resumption.

This PR: delete the graph_out_{n} symbol after restoring local vars so the
lifetime of the local vars is governed by themselves.

## Example Problem
Tensor `b`'s last usage is in the graph break. However, it won't be deallocated until `bar()` completes. In the orignal issue report by @Yuzhen11, `b` is a large tensor and `bar()` is an expensive computation.

```python
import torch

def foo(a):
    return torch.mm(a, a)

@torch._dynamo.disable()
def graph_break_fn(a):
    ret = a.bfloat16()
    return ret

def bar(c):
    return torch.mm(c, c)

def fn(a):
    b = foo(a)
    c = graph_break_fn(b)
    # del b
    return bar(c)

fn_compiled = torch.compile(fn, backend="eager")
a = torch.randn(10000, 10000, device="cuda", requires_grad=True)

fn_compiled(a).sum().backward()
```

Bytecode before this PR:
```
ORIGINAL BYTECODE fn /home/yifu/microbench/del2.py line 18
 19           0 LOAD_GLOBAL              0 (foo)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               1 (b)

 20           8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                1 (b)
             12 CALL_FUNCTION            1
             14 STORE_FAST               2 (c)

 22          16 LOAD_GLOBAL              2 (bar)
             18 LOAD_FAST                2 (c)
             20 CALL_FUNCTION            1
             22 RETURN_VALUE

MODIFIED BYTECODE fn /home/yifu/microbench/del2.py line 18
 18           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               3 (graph_out_0)
              8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                3 (graph_out_0)
             12 LOAD_CONST               1 (0)
             14 BINARY_SUBSCR

 20          16 CALL_FUNCTION            1
             18 LOAD_GLOBAL              4 (__resume_at_14_1)
             20 ROT_TWO
             22 CALL_FUNCTION            1
             24 RETURN_VALUE

ORIGINAL BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_FAST                0 (___stack0)
              2 JUMP_ABSOLUTE            9 (to 18)
              4 LOAD_GLOBAL              0 (foo)
              6 LOAD_FAST                1 (a)
              8 CALL_FUNCTION            1
             10 STORE_FAST               2 (b)
             12 LOAD_GLOBAL              1 (graph_break_fn)
             14 LOAD_FAST                2 (b)
             16 CALL_FUNCTION            1
        >>   18 STORE_FAST               3 (c)

 22          20 LOAD_GLOBAL              2 (bar)
             22 LOAD_FAST                3 (c)
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

MODIFIED BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_GLOBAL              3 (__compiled_fn_2)
              2 LOAD_FAST                0 (___stack0)
              4 CALL_FUNCTION            1
              6 UNPACK_SEQUENCE          1
              8 RETURN_VALUE
```

Bytecode after this PR:
```
ORIGINAL BYTECODE fn /home/yifu/microbench/del2.py line 18
 19           0 LOAD_GLOBAL              0 (foo)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               1 (b)

 20           8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                1 (b)
             12 CALL_FUNCTION            1
             14 STORE_FAST               2 (c)

 22          16 LOAD_GLOBAL              2 (bar)
             18 LOAD_FAST                2 (c)
             20 CALL_FUNCTION            1
             22 RETURN_VALUE

MODIFIED BYTECODE fn /home/yifu/microbench/del2.py line 18
 18           0 LOAD_GLOBAL              3 (__compiled_fn_0)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 STORE_FAST               3 (graph_out_0)
              8 LOAD_GLOBAL              1 (graph_break_fn)
             10 LOAD_FAST                3 (graph_out_0)
             12 LOAD_CONST               1 (0)
             14 BINARY_SUBSCR
             16 DELETE_FAST              3 (graph_out_0)

 20          18 CALL_FUNCTION            1
             20 LOAD_GLOBAL              4 (__resume_at_14_1)
             22 ROT_TWO
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

ORIGINAL BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_FAST                0 (___stack0)
              2 JUMP_ABSOLUTE            9 (to 18)
              4 LOAD_GLOBAL              0 (foo)
              6 LOAD_FAST                1 (a)
              8 CALL_FUNCTION            1
             10 STORE_FAST               2 (b)
             12 LOAD_GLOBAL              1 (graph_break_fn)
             14 LOAD_FAST                2 (b)
             16 CALL_FUNCTION            1
        >>   18 STORE_FAST               3 (c)

 22          20 LOAD_GLOBAL              2 (bar)
             22 LOAD_FAST                3 (c)
             24 CALL_FUNCTION            1
             26 RETURN_VALUE

MODIFIED BYTECODE torch_dynamo_resume_in_fn_at_20 /home/yifu/microbench/del2.py line 20
 20           0 LOAD_GLOBAL              3 (__compiled_fn_2)
              2 LOAD_FAST                0 (___stack0)
              4 CALL_FUNCTION            1
              6 UNPACK_SEQUENCE          1
              8 RETURN_VALUE

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122658
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-03-26 22:49:05 +00:00
Jason Ansel
5b5c167adc [dynamo] Add some helpers to PyCodegen (#120684)
This are used in later PRs in the stack

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120684
Approved by: https://github.com/yanboliang
2024-02-27 18:46:51 +00:00
Jason Ansel
2fea475215 [dynamo] Refactor reconstruct() not to return anything (#120150)
This simplifies things slightly and avoids some bugs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120150
Approved by: https://github.com/yanboliang
2024-02-17 17:13:41 +00:00
Jason Ansel
39c68efd85 [dynamo] Capture untyped_storage().resize_() (#119647)
This makes storage resizing work with `backend=eager`, the next two PRs make it work for inductor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119647
Approved by: https://github.com/yf225
2024-02-13 19:03:28 +00:00
rzou
5e0ef84b01 [dynamo] Refactor install_global_once, remove usages of install_global_unsafe (#118100)
We split install_global_once into two APIs:
- `install_global_by_id(prefix, value) -> name`: installs a global if it hasn't
been installed yet
- `install_global(prefix, value) -> name`: always installs the global (and
  generates a unique name for it)

Then, we refactor most callsites of `install_global_unsafe` to one of
the previous. Some callsites cannot be refactored because we create the
global name first, do a lot of stuff with it, and then install it.

This fixes more test flakiness.

Test Plan:
- Existing tests; I can't reliably repro the flakiness
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118100
Approved by: https://github.com/ezyang, https://github.com/mlazos
2024-01-24 23:25:44 +00:00
rzou
af7cd5c32a [Dynamo] Install module globals per output_graph (#117998)
Fixes https://github.com/pytorch/pytorch/issues/117851

In tests, we ran into an issue where:
- In frame A, Dynamo would install a global
- We call reset()
- reset() did not delete the installed global due to a refcycle
- In frame B, Dynamo would re-use the same global
- Python gc ran, deleting the installed global, leading to the compiled
  version of frame B raising NameNotFound

This PR changes the following:
- module globals are now installed at a per-frame basis.
- renames install_global to install_global_unsafe: if the names are not
  unique and end up being re-used across frames, then we've got trouble.

Test Plan:
- I tested that this got rid of the test flakiness locally. I'm not sure
  how to easily write a test for this, because I don't actually know
  what the refcycle in the above is.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117998
Approved by: https://github.com/ezyang, https://github.com/anijain2305
2024-01-23 02:28:02 +00:00
voznesenskym
83e8a0721d Reland #111196 (take 4) "Support tensors as Dict keys" (#116934)
Fixes #ISSUE_NUMBER

See that PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116934
Approved by: https://github.com/ezyang, https://github.com/huydhn
2024-01-07 01:37:26 +00:00
PyTorch MergeBot
2dca3e99eb Revert "Support tensors as Dict keys Re-PR of #111196 (#116785)"
This reverts commit 1badad9ce9.

Reverted https://github.com/pytorch/pytorch/pull/116785 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/116785#issuecomment-1879592261))
2024-01-06 08:22:33 +00:00
voznesenskym
1badad9ce9 Support tensors as Dict keys Re-PR of #111196 (#116785)
This prepares the PR where we implement sets in terms of dicts.
To do so, rather than storing internally a dictionary that maps literals
to VariableTrackers, it stores (pretty much) a dictionary from VTs to VTs.
To do so, keys are wrapped in an opaque internal class _Hashable.
The Hashable class is opaque on purpose so that it fails hard if
if it inadvertently leaks back into user code.
We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make much clearer what are the things that need to be modified to add
a new supported key type to Dicts.

Fixes [#107595](https://www.internalfb.com/tasks?t=107595)
Fixes [#111603](https://www.internalfb.com/tasks?t=111603)

Re-PR of https://github.com/pytorch/pytorch/pull/111196 sadly due to reverts, we could not reuse @lezcano's original PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116785
Approved by: https://github.com/mlazos
2024-01-06 03:35:35 +00:00
Michael Lazos
8eb7f6276b Ensure wrapping subclasses with as_subclass is supported (#116091)
As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116091
Approved by: https://github.com/pmeier, https://github.com/zou3519
2023-12-20 14:37:08 +00:00
zhxchen17
f78f23d753 [export] Turn off output value from sources for export. (#115442)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115442
Approved by: https://github.com/tugsbayasgalan
2023-12-12 22:41:23 +00:00
Michael Lazos
fbeca60b1f Remove replace_all and make VTs mutable (#113725)
1.  Removes calls to `replace_all` and `clone` and makes VTs mutable.
2. Properly handles Tuple Iterator mutation. Previously TupleIterator variables would only be properly reconstructed if they were advanced at least once in a frame. On calls to `next`, the source information would be lost (due to constructing a new iterator without using builder), which would ensure that during codegen the variable would be reconstructed from scratch. Now that VTs are mutated, the source is never lost, so we need to properly track mutation and handle it by replaying calls to `next` at the end of the modified bytecode.
3. Added test for checking iadd side effects, this was missing in our unit test coverage.
4. Fixed two incorrect sources, DelayGraphBreakVariable, and UserMethodVariable both relied on setting the source to AttrSource(parent, name) at the callsite of `var_getattr`.
5. Fixed a bug in inplace adding for lists, it would set the resulting VariableTracker's source to `None` which would utilize a different reconstruct path in codegen. Now this is handled explicitly by reconstructing vars when allow_cache=`False`, so that during side effect replay, the mutated var is correctly updated.

In subsequent PRs:
* Refactoring side effect tracking to be significantly simpler (I think we only need an `is_modified` flag)
* Refactor `next_variables` iterator to match the signature of `next`
* Remove all references to `options` in the code
* Refactor VTs representing mutable collections to implement their own mutation update handling
* Remove clone and/or make it specific to lists for creating slices
* Add mutation tracking/replay for sets
* Add mutation tracking/replay for iter.py
* Removing setting source in builder (it's set at the top level after a var is returned)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725
Approved by: https://github.com/jansel
2023-12-10 09:31:21 +00:00
PyTorch MergeBot
3e47e3f441 Revert "[export] Fix graph output mismatch issue with constant outputs. (#115280)"
This reverts commit 622688fab9.

Reverted https://github.com/pytorch/pytorch/pull/115280 on behalf of https://github.com/atalman due to ghfirst issue when importing, will reland this PR ([comment](https://github.com/pytorch/pytorch/pull/115280#issuecomment-1847903624))
2023-12-08 22:10:03 +00:00
PyTorch MergeBot
3dab46fe19 Revert "[export] Dont skip output caching for now. (#115374)"
This reverts commit fd79995fd6.

Reverted https://github.com/pytorch/pytorch/pull/115374 on behalf of https://github.com/atalman due to ghfirst issue when importing, will reland this PR ([comment](https://github.com/pytorch/pytorch/pull/115374#issuecomment-1847899901))
2023-12-08 22:06:21 +00:00
zhxchen17
fd79995fd6 [export] Dont skip output caching for now. (#115374)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115374
Approved by: https://github.com/tugsbayasgalan
2023-12-07 20:31:30 +00:00
zhxchen17
622688fab9 [export] Fix graph output mismatch issue with constant outputs. (#115280)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115280
Approved by: https://github.com/tugsbayasgalan
2023-12-07 06:11:08 +00:00
Jason Ansel
aa70e31610 [dynamo] Fix MutableSideEffects returning alias (#115095)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115095
Approved by: https://github.com/yanboliang
2023-12-05 19:01:03 +00:00
PyTorch MergeBot
5d170fce29 Revert "Support tensors as Dict keys (#111196)"
This reverts commit b0805fa5d0.

Reverted https://github.com/pytorch/pytorch/pull/111196 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it is failing internally. I will provide the details there ([comment](https://github.com/pytorch/pytorch/pull/111196#issuecomment-1813410149))
2023-11-15 23:08:00 +00:00
lezcano
b0805fa5d0 Support tensors as Dict keys (#111196)
This prepares the PR where we implement sets in terms of dicts.
To do so, rather than storing internally a dictionary that maps literals
to VariableTrackers, it stores (pretty much) a dictionary from VTs to VTs.
To do so, keys are wrapped in an opaque internal class `_Hashable`.
The Hashable class is opaque on purpose so that it fails hard if
if it inadvertently leaks back into user code.

We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make much clearer what are the things that need to be modified to add
a new supported key type to Dicts.

Fixes https://github.com/pytorch/pytorch/issues/107595
Fixes https://github.com/pytorch/pytorch/issues/111603
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111196
Approved by: https://github.com/jansel
2023-11-14 19:14:03 +00:00
Jason Ansel
3914566c73 [dynamo] Refactor OrderedDict to dict (#113234)
In Python3 all dicts are ordered.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113234
Approved by: https://github.com/oulgen, https://github.com/lezcano
2023-11-08 09:27:08 +00:00
Jason Ansel
843a8ecd24 [dynamo] Remove VariableTracker.add_options (#111725)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111725
Approved by: https://github.com/voznesenskym
ghstack dependencies: #111306, #111415
2023-11-07 19:55:19 +00:00
Jason Ansel
9664190952 [dynamo] Eagerly install guards (#111415)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111415
Approved by: https://github.com/voznesenskym
ghstack dependencies: #111306
2023-11-07 19:55:19 +00:00
Jez Ng
b8ac5bbcbd [dynamo] Enable typechecking for bytecode_transformation.py (#112561)
As part of this diff, I have upgraded the `python_version` config setting to 3.11. `bytecode_transformation.py` (and a few other files) have functions using APIs only available in Python 3.11+. Those APIs are gated by a sys.version_info check in their typeshed .pyi files. So setting the min version to 3.11 allows those functions to typecheck properly.

An alternative is to make the relevant types Any:

```
if sys.version_info >= (3, 11):
    _Positions = dis.Positions
else:
    _Positions = Any
```

However, with python_version = 3.8, that means we're not getting any useful typechecking signal when encountering values of type _Position.

Changing the python_version to 3.11 does mean that we will stop typechecking codepaths that run only on lower versions, but that seems a small price to pay. It does also mean that we won't catch code that uses newer APIs without the appropriate version check, but again, not sure this has much of an impact.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112561
Approved by: https://github.com/ezyang
2023-11-04 19:36:27 +00:00
Jez Ng
413baa1b25 [dynamo] Enable typechecking for codegen.py (#111992)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111992
Approved by: https://github.com/Skylion007, https://github.com/eellison
ghstack dependencies: #111894
2023-10-26 04:54:16 +00:00
Michael Lazos
1d9a7f9e43 [Reland] TensorWithTFOverride inheritance from TensorVariable (#111766)
Accidentally merged https://github.com/pytorch/pytorch/pull/111730 with ghstack, so relanding

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111766
Approved by: https://github.com/jansel
2023-10-23 04:33:16 +00:00