Commit Graph

206 Commits

Author SHA1 Message Date
Animesh Jain
e68d65dae2 [dynamo][cpp-guards] Differentiate dict guards wrt to guarding on key order (#124779)
We guard on key order
1) When a key is a non-constant object
2) When we actually need key order - like .values, .items etc

For dicts/OrderedDicts that do not require key order guarding, we just rely on usual `GuardManger + DictGetItemGuardAccessor`. This is faster than going through the `list(d.keys())` based design for OrderedDicts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124779
Approved by: https://github.com/jansel
2024-04-25 08:20:35 +00:00
Xuehai Pan
7e1c98c171 [dynamo] support object.__setattr__(obj, name, value) (#124068)
Resolves #114964
Resolves #114966

- #114964
- #114966

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124068
Approved by: https://github.com/jansel
2024-04-17 15:57:14 +00:00
Animesh Jain
58afcd7b61 [dynamo][dict] Add UnspecializedNNModuleVariable to dict keys (#122812)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122812
Approved by: https://github.com/jansel
ghstack dependencies: #122943, #123877, #123878
2024-04-13 02:07:35 +00:00
Jason Ansel
6b0ba6bbd3 [dynamo] Improve constant-prop for regex/torch.__version__ (#123705)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123705
Approved by: https://github.com/anijain2305
ghstack dependencies: #123700
2024-04-12 19:03:13 +00:00
Jason Ansel
212e460dce [dynamo] Support custom __setattr__ on UserDefinedObjectVariable (#123318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123318
Approved by: https://github.com/anijain2305
2024-04-07 21:06:52 +00:00
Jason Ansel
781e8d2201 [dynamo] Support __next__ on UserDefinedObjectVariable (#122565)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122565
Approved by: https://github.com/yanboliang
2024-03-31 19:00:03 +00:00
Brian Hirsh
2e44b12dd4 dynamo: handle DTensor.device_mesh.device_type (#118803)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118803
Approved by: https://github.com/wanchaol, https://github.com/yanboliang
2024-03-22 14:42:22 +00:00
Jason Ansel
477d154ffd [dynamo] Add missing _nonvar_fields annotations (#122219)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122219
Approved by: https://github.com/anijain2305
ghstack dependencies: #122218
2024-03-20 07:53:18 +00:00
Jason Ansel
153a01833b [dynamo] Optimize SourcelessBuilder (#122063)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 2.7s to 2.5s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122063
Approved by: https://github.com/anijain2305
ghstack dependencies: #122039, #122043, #122055, #122058, #122060
2024-03-19 04:23:30 +00:00
Jason Ansel
2bec55c5f9 [dynamo] Remove VariableTracker.parents_tracker (#122058)
This is leftover from mutable variable tracker days and no longer needed.

Improves benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py
from 4.2s to 3.9s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122058
Approved by: https://github.com/oulgen, https://github.com/anijain2305
ghstack dependencies: #122039, #122043, #122055
2024-03-19 04:23:24 +00:00
Jason Ansel
3c706bf483 [dynamo] Optimize BuiltinVariable (#122055)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 5.1s to 4.2s (compared to 2 PRs ago).

This works by precomputing (and caching) the parts of `BuiltinVariable.call_function` that don't depend on the values of args/kwargs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122055
Approved by: https://github.com/oulgen, https://github.com/anijain2305
ghstack dependencies: #122039, #122043
2024-03-19 04:23:20 +00:00
Jason Ansel
07caea5c12 [dynamo] Refactor COMPARE_OP and comparison builtins (#122043)
This removes the duplicate handling of comparison ops between symbolic_convert and bultin and refactors the handling to use the binop infrastructure.  This change regresses overheads a bit, but this is fixed in the next PR.

New test skips are variants of `type(e) is np.ndarray` previously falling back to eager.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122043
Approved by: https://github.com/anijain2305
ghstack dependencies: #122039
2024-03-19 04:23:17 +00:00
Jason Ansel
769ff86b91 [dynamo] Optimize COMPARE_OP (#122039)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 5.6 to 5.1s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122039
Approved by: https://github.com/Skylion007, https://github.com/anijain2305
2024-03-19 04:23:14 +00:00
Jason Ansel
4034873a31 [dynamo] Optimize builtin handling (#122035)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 7.3s to 6.7s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122035
Approved by: https://github.com/Skylion007
ghstack dependencies: #122032, #122033, #122034
2024-03-18 18:08:06 +00:00
Jason Ansel
6ca0323615 [dynamo] Optimize VariableTracker.__post_init__ (#122034)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 8.6s to 7.3s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122034
Approved by: https://github.com/Skylion007
ghstack dependencies: #122032, #122033
2024-03-18 18:08:06 +00:00
Aaron Gokaslan
d55d803812 Add operator length hint support (#121495)
Seemed like an easy operator to squeeze into Python 2.3 . Added a simple test. Partially addresses #116396

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121495
Approved by: https://github.com/albanD
2024-03-08 19:08:33 +00:00
Yichen Yan
e50ded03a6 Use type check for also is_not (#113859)
Handle `is_not` for:

9647a251cb/torch/_dynamo/variables/builtin.py (L1314-L1317)

I noticed https://github.com/pytorch/pytorch/issues/111713 exists, I think it's no harm to land this first.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113859
Approved by: https://github.com/Skylion007
2024-03-06 23:12:42 +00:00
Jason Ansel
35004b8ab4 [dynamo] Fix handling of invalid args (#121110)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121110
Approved by: https://github.com/yanboliang
ghstack dependencies: #121106
2024-03-05 17:16:04 +00:00
Animesh Jain
8e4301077e [dynamo][comp-time] BuiltinVariableTracker - inspect signature only on failure (#121053)
Reduces the torch.compile(backend="eager") for this code by 1-2 seconds.
~~~
def fn(x):
    for _ in range(10000):
        # x = torch.sin(x)
        x = torch.ops.aten.sin(x)
        # x = sin(x)

    return x
~~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121053
Approved by: https://github.com/jansel
2024-03-02 23:03:00 +00:00
Animesh Jain
5a53c0ff23 [dynamo][refactor] Rename LIST_LENGTH to SEQUENCE_LENGTH, separate DICT_LENGTH (#120721)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120721
Approved by: https://github.com/jansel
ghstack dependencies: #120520, #120590
2024-02-28 02:19:10 +00:00
Animesh Jain
e3d64c4d5d [dynamo] Desugar accumulate_grad, fix .grad handling (#120590)
Fixes https://github.com/pytorch/pytorch/issues/118435
Fixes https://github.com/pytorch/pytorch/issues/119906

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120590
Approved by: https://github.com/ezyang, https://github.com/jansel
ghstack dependencies: #120520
2024-02-27 10:12:26 +00:00
Jason Ansel
2fea475215 [dynamo] Refactor reconstruct() not to return anything (#120150)
This simplifies things slightly and avoids some bugs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120150
Approved by: https://github.com/yanboliang
2024-02-17 17:13:41 +00:00
Oguz Ulgen
62e5840b36 [Dynamo] Do not create TorchInGraphFunctionVariable for tags (#120005)
Fixes https://github.com/pytorch/pytorch/issues/119793

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120005
Approved by: https://github.com/yanboliang
2024-02-16 03:37:32 +00:00
ydwu4
b251bca205 [dynamo] inlining into __iter__ of user defined object (#119243)
Fixes #119198.

This PR make dynamo inline `__iter__` of a user defined object instead of creating a graph break. Also added a new test, which shows:
1. the loop is unrolled
2. the length of the loop is guarded when inlining `__iter__`
```python
class Mod:
    def __init__(self):
        self.a = [torch.randn(2, 2), torch.randn(2, 2)]

    def __iter__(self):
        return iter(self.a)

def f(mod):
    ret = []
    for x in mod:
        ret.append(x + 1)
    return ret
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119243
Approved by: https://github.com/jansel
2024-02-08 17:07:30 +00:00
Yanbo Liang
0f478d9d61 [Dynamo][15/N] Merge allow_in_graph/inline/skip trace rules check into trace_rule.lookup (#118971)
Finally we have this PR to merge allow_in_graph/inline/skip trace rules into ```trace_rules.lookup_inner```, where we can define and lookup trace rules at both function level and file level. Going forward, this is the central place that we define and consulte Dynamo trace rule for any function.
* ```trace_rules.looup``` is the API can return allow_in_graph, inline or skip.
* ```skipfiles.check``` is the API can return inline or skip, since we have multiple places that only do inline/skip check.
  *  I'll move ```skipfiles.check``` to ```trace_rules.check``` as one of the follow-ups.
* Both functions consulte ```trace_rules.lookup_inner``` to get the tracing rule.

To avoid a single big PR, I left a few items as the follow-ups:
* Remove ```skipfiles.py``` and merge the code into ```trace_rules.py```.
* We do double check in ```symbolic_convert.check_inlineable```, will refactor and simplify it. We should only do inline/skip check before generating ```SkipFilesVariable``` and ```UserFunctionVariable```.
* Rename ```SkipFilesVariable``` as ```SkipFunctionVariable```, since we only handle functions.
* The inline/skip reasons are not logged for some cases, since the new lookup framework doesn't always return inline/skip reasons. I'll refactor loggings to record the inline/skip reason in next step.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118971
Approved by: https://github.com/jansel
2024-02-07 05:15:39 +00:00
ydwu4
5410385c42 [dynamo] support comparing stream with constant (#119199)
Before the pr, we have a graph break for:
```python
def f():
    if torch.cuda.current_stream() is not None:
        return torch.randn(2, 2)
torch.compile(f, backend="eager", fullgraph=True)()
```
This pr supports comparson ops of StreamVariable and ConstantVariable by returning a constant.

It's safe to return a constant in this case becuase the StreamVariable is guarded by ID_MATCH when created.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119199
Approved by: https://github.com/yifuwang, https://github.com/anijain2305, https://github.com/jansel
2024-02-06 19:26:03 +00:00
Yanbo Liang
8ee9f26ce8 [Dynamo] Remove build_checkpoint_variable from call_getattr (#119236)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119236
Approved by: https://github.com/jansel
2024-02-06 16:59:40 +00:00
Edward Z. Yang
3c0c387429 Support symbolic min/max on unbacked SymInt (#118953)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118953
Approved by: https://github.com/ColinPeppler, https://github.com/aakhundov
2024-02-02 20:01:46 +00:00
lezcano
b1da929df9 Use SourcelesBuilder in BuiltinVariable (#118098)
This was failing when fetching a dictionary from a module

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118098
Approved by: https://github.com/peterbell10, https://github.com/anijain2305
ghstack dependencies: #117982
2024-02-02 14:37:23 +00:00
rzou
a16df1d85f [Dynamo] graph break on isinstance calls if we don't know the type (#118778)
If we can't figure out the python type of a VariableTracker, then the
isinstance call should graph break (instead of raising an error).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118778
Approved by: https://github.com/ydwu4
ghstack dependencies: #118768
2024-02-01 23:18:10 +00:00
Yanbo Liang
4fc4f5eb06 [Dynamo] Support tensor is not tensor (#118840)
Fixes Meta internal use case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118840
Approved by: https://github.com/yf225
2024-02-01 07:32:43 +00:00
laith sakka
8455447972 Support builtin callable with object arguments in dynamo (#118678)
Fix issue #117556

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118678
Approved by: https://github.com/anijain2305
2024-01-31 17:54:08 +00:00
Yanbo Liang
ca1d70632d [14/N][Dynamo] Make trace_rules.lookup only handle function + callable type (#118366)
Step by step changes to unblock #118264

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118366
Approved by: https://github.com/angelayi
2024-01-27 23:02:44 +00:00
Edward Z. Yang
d03173e88c Unify MYPYINDUCTOR and MYPY (#118432)
The original motivation for MYPYINDUCTOR was a faster type checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this.

Perhaps erroneously, when I tee'ed up this PR I elected to delete the `follow_imports = skip` designations in the mypy-inductor.ini. This lead to a number of extra type error suppressions that I manually edited. You will need to review.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
2024-01-27 17:23:20 +00:00
laith sakka
b47cf4182e Fix support non tensor inputs to operator.pos function (#118251)
Fixes #118231

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118251
Approved by: https://github.com/Skylion007, https://github.com/anijain2305
2024-01-25 20:37:40 +00:00
Guilherme Leobas
80cf0ce153 Enhance torch.vmap support from inside torch.compile (#116050)
This work rewrites vmap support in torch.compile by inlining most of
the frames into the existing FX graph. It also unlocks to PyTorch to
support features that were previously missing, such as keyword args.

Fixes: https://github.com/pytorch/pytorch/issues/114306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116050
Approved by: https://github.com/zou3519
2024-01-22 17:53:45 +00:00
ydwu4
eba5d5485d [dynamo] make ConstantSource propagate through built-in ops for TensorVariable (#117704)
Fixes #117685.

This PR only makes ConstantSource perserved for built-in ops when we find all the inputs are either constant tensors or python constants.

 It doesn't fundamentally solve the problem of preserving ConstantSource information through all operators that's potentially can be constant folded.

For the following code in the issue:
```
class Bob(torch.nn.Module):
    def __init__(self, p, val) -> None:
        super().__init__()
        self.p = p
        self.y = torch.nn.Parameter(torch.tensor(val))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This only looks dynamic but it's actually a constant value
        if get_y(self.y) < self.p:
            return torch.cat([x,x])
        else:
            return x
```
The graph exported looks like following:
```python
class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: "f32[s0, s1]";

        arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        l_x_ = arg0

        # File: /home/yidi/local/pytorch/test/dynamo/test_export.py:1498 in forward, code: return torch.cat([x, x])
        cat = torch.cat([l_x_, l_x_]);  l_x_ = None
        return pytree.tree_unflatten([cat], self._out_spec)
```

Test Plan:
Added a new test for the given repro.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117704
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-01-18 20:18:34 +00:00
Animesh Jain
6e4e81a9ef [dynamo] Extend LazyVariableTracker to tuples (#117426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117426
Approved by: https://github.com/lezcano, https://github.com/jansel
2024-01-18 15:51:28 +00:00
lezcano
4ba5318d3f [dynamo] Add DictView variable tracker (#108420)
This also starts a comparison pattern where we don't ask variables
what's their type, but what are their capabilities.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108420
Approved by: https://github.com/jansel
ghstack dependencies: #112252, #117630, #110524
2024-01-18 09:37:33 +00:00
lezcano
f4df0f061c Implement set in terms of dict (#110524)
This allows to heavily simplify the implementation of set, which was
"quite unique". Now we represent a set a as a dict where all its values
are None.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110524
Approved by: https://github.com/jansel
ghstack dependencies: #112252, #117630
2024-01-18 09:36:41 +00:00
Aaron Gokaslan
62496ffd0d [dynamo][easy]: Add support for operator.truth (#117463)
* This is an old builtin function equivalent to the bool constructor. it is easy enough to add support for.
* I also realized the tests were in the wrong class (the one reserved for testing default args) so I moved them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117463
Approved by: https://github.com/jansel
2024-01-14 19:08:31 +00:00
Aaron Gokaslan
bf27dd6df9 Add dynamo support for operator.abs (#117442)
A test case for operator.abs and allows for constant folding with it. Partially applies to #116396

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117442
Approved by: https://github.com/jansel, https://github.com/malfet
2024-01-13 21:38:55 +00:00
Aaron Gokaslan
1dd4813328 [BE][dynamo]: Add operator is and is not tests to dynamo tests (#116397)
Adds an operator that was unit not tested in our test suite - improves coverage. Inspired by looking into https://github.com/pytorch/pytorch/pull/116397 after @XuehaiPan brought up some issues with builtins in #116389

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116397
Approved by: https://github.com/albanD, https://github.com/jansel
2024-01-09 21:13:22 +00:00
voznesenskym
83e8a0721d Reland #111196 (take 4) "Support tensors as Dict keys" (#116934)
Fixes #ISSUE_NUMBER

See that PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116934
Approved by: https://github.com/ezyang, https://github.com/huydhn
2024-01-07 01:37:26 +00:00
PyTorch MergeBot
2dca3e99eb Revert "Support tensors as Dict keys Re-PR of #111196 (#116785)"
This reverts commit 1badad9ce9.

Reverted https://github.com/pytorch/pytorch/pull/116785 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/116785#issuecomment-1879592261))
2024-01-06 08:22:33 +00:00
voznesenskym
1badad9ce9 Support tensors as Dict keys Re-PR of #111196 (#116785)
This prepares the PR where we implement sets in terms of dicts.
To do so, rather than storing internally a dictionary that maps literals
to VariableTrackers, it stores (pretty much) a dictionary from VTs to VTs.
To do so, keys are wrapped in an opaque internal class _Hashable.
The Hashable class is opaque on purpose so that it fails hard if
if it inadvertently leaks back into user code.
We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make much clearer what are the things that need to be modified to add
a new supported key type to Dicts.

Fixes [#107595](https://www.internalfb.com/tasks?t=107595)
Fixes [#111603](https://www.internalfb.com/tasks?t=111603)

Re-PR of https://github.com/pytorch/pytorch/pull/111196 sadly due to reverts, we could not reuse @lezcano's original PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116785
Approved by: https://github.com/mlazos
2024-01-06 03:35:35 +00:00
Guoliang He
0159e3abbd [dynamo] add a handler for itertools_chain_from_iterable and test (#116849)
1. add a handler for itertools_chain_from_iterable
2. a test for itertools_chain_from_iterable

Fixes #116463

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116849
Approved by: https://github.com/ezyang
2024-01-05 15:14:18 +00:00
Aaron Gokaslan
bd10fea79a [BE]: Enable F821 and fix bugs (#116579)
Fixes #112371

I tried to fix as many of the bugs as I could, a few I could not figure out what the proper fix for them was though and so I left them with noqas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
2024-01-01 08:40:46 +00:00
Xuehai Pan
3149e4a667 [dynamo] fix sum() function with start argument (#116389)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116389
Approved by: https://github.com/Skylion007, https://github.com/malfet
2023-12-27 20:42:27 +00:00
PyTorch MergeBot
e0e90bc0d4 Revert "[dynamo] fix sum() function with start argument (#116389)"
This reverts commit 3c9076f070.

Reverted https://github.com/pytorch/pytorch/pull/116389 on behalf of https://github.com/kit1980 due to Breaks Meta-internal tests, but the issue could have been caught on GitHub ([comment](https://github.com/pytorch/pytorch/pull/116389#issuecomment-1870556927))
2023-12-27 19:05:55 +00:00