Previously, the decomposition would upcast inputs to fp32. This led to a slowdown compared to eager, which runs in fp16. We also tried keeping the bmm in fp16 and upcasting only for the epilogue, but that led to worse numerics, because eager does the epilogue entirely in fp32 without downcasting the bmm accumulator.
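To illustrate the numerics point, here is a minimal sketch in plain Python, not the actual decomposition. It assumes `struct`'s half-precision format as a stand-in for fp16 rounding and Python floats as a stand-in for the higher-precision accumulator; `to_fp16`, `dot`, and the input values are hypothetical and chosen only to make the effect visible.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def dot(a, b):
    # Python floats stand in for the higher-precision accumulator.
    return sum(x * y for x, y in zip(a, b))


# All inputs are exactly representable in fp16.
a = [64.0, 0.5]
b = [32.0, 1.0]
bias = -2048.0

acc = dot(a, b)  # 2048.5, not representable in fp16 (spacing is 2 near 2048)

# Epilogue applied to the un-rounded accumulator, single downcast at the end.
fused = to_fp16(acc + bias)  # 0.5

# Accumulator downcast to fp16 first, then the epilogue.
split = to_fp16(to_fp16(acc) + bias)  # 0.0
```

The intermediate downcast discards the low bits of the accumulator before the epilogue can use them, which is the kind of discrepancy described above.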
Fix for https://github.com/pytorch/pytorch/issues/137897
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137904
Approved by: https://github.com/ngimel
Like the previous two PRs, this does the rebinding and binding computation, just in FakeTensorUpdater. FakeTensorUpdater modifies the FX graph in place, so its usage pattern is slightly different, but still pretty short.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124314
Approved by: https://github.com/IvanKobzarev, https://github.com/lezcano
ghstack dependencies: #124310
The original motivation for MYPYINDUCTOR was a faster type checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this.
Perhaps erroneously, when I teed up this PR I elected to delete the `follow_imports = skip` designations in mypy-inductor.ini. This led to a number of extra type error suppressions that I manually edited. You will need to review.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
This expands the reinplacing pass to allow reinplacing view-scatter operations.
For example, if our Python code is:
```python
a = view1(inp)
b = view2(a)
b.copy_(src)
```
this generates a functionalized graph like:
```python
a = view1(inp)
a_updated = view2_scatter(a, src)
inp_updated = view1_scatter(inp, a_updated)
```
First, the `canonicalize_view_scatter_ops` step rewrites the functionalized graph
into the form:
```python
inp_updated = _generalized_scatter(inp, src, [view1, view2])
a_updated = view1(inp_updated)
```
I then register `_generalized_scatter` as a normal inplaceable op which can be
handled by the pre-existing mechanism. Since we've fused the two scatter ops into one,
the reinplacing pass sees only one user of `inp` which allows the entire operation to be
reinplaced if desired (and I add heuristics that sometimes choose not to reinplace).
Finally, there is a decomposition step which decomposes out-of-place or in-place
`_generalized_scatter` operations either back into view_scatter operations, or
into the version with mutations. When introducing mutations, the reinplaced
version is equivalent to the original mutation:
```python
a = view1(inp)
b = view2(a)
b.copy_(src)
```
Or, when out-of-place, we end up with a minor restructuring of the graph:
```python
a = view1(inp)
tmp = view2_scatter(a, src)
inp_updated = view1_scatter(inp, tmp)
a_updated = view1(inp_updated)
```
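As a rough illustration of why the fused form is equivalent, here is a toy model on plain Python lists, with slices standing in for views. The helpers `view`, `view_scatter`, and `generalized_scatter` are hypothetical stand-ins written for this sketch, not inductor's implementation, and they assume `src` has the same length as the selected region.

```python
from typing import List, Sequence


def view(base: List[int], s: slice) -> List[int]:
    # Out-of-place stand-in for a view: returns the selected elements.
    return base[s]


def view_scatter(base: List[int], src: List[int], s: slice) -> List[int]:
    # Returns a copy of `base` with the region selected by `s` replaced by `src`.
    out = list(base)
    out[s] = src
    return out


def generalized_scatter(
    inp: List[int], src: List[int], views: Sequence[slice]
) -> List[int]:
    # Scatter `src` through a chain of views, innermost view last.
    if not views:
        return list(src)
    first, rest = views[0], views[1:]
    inner = generalized_scatter(view(inp, first), src, rest)
    return view_scatter(inp, inner, first)


inp = list(range(8))
view1, view2 = slice(2, 8), slice(1, 3)
src = [-1, -2]

# Functionalized graph: two separate scatter ops.
a = view(inp, view1)
a_updated = view_scatter(a, src, view2)
inp_updated = view_scatter(inp, a_updated, view1)

# Canonicalized graph: one fused scatter, then re-derive the view.
fused = generalized_scatter(inp, src, [view1, view2])
a_from_fused = view(fused, view1)
```

In this model both forms produce the same `inp_updated` and `a_updated`, but the fused form has a single user of `inp`, which is the property the reinplacing pass relies on.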
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116899
Approved by: https://github.com/lezcano
ghstack dependencies: #116898, #117121
This adds a function `statically_known_true` for `SymBool` that works
like inductor's `is_expr_static_and_true`. That is, it tries to simplify the
expression to a constant or returns `False` if it cannot be simplified.
This is useful for optimizations that apply only when the condition is met;
since skipping them doesn't affect correctness, we can avoid adding guards.
I also use this new function in inductor for `FakeTensorUpdater` and
`remove_noop_pass` which both generated unexpected guards previously.
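The intended semantics can be sketched with a toy model. This is not the PyTorch implementation: `SymRange`, `try_eval_ge`, and `known_true` are hypothetical names, and a value range stands in for whatever the real simplifier knows about a symbol.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SymRange:
    """Toy symbolic integer: only its inclusive value range is known."""
    lo: int
    hi: int


def try_eval_ge(sym: SymRange, bound: int) -> Optional[bool]:
    # Returns True/False when the range decides `sym >= bound`, else None.
    if sym.lo >= bound:
        return True
    if sym.hi < bound:
        return False
    return None


def known_true(result: Optional[bool]) -> bool:
    # Only a definite True counts; "unknown" is treated as False and
    # nothing is recorded (no guard), unlike forcing the condition.
    return result is True


size = SymRange(lo=2, hi=1024)
take_fast_path = known_true(try_eval_ge(size, 2))    # decided: True
maybe_fast_path = known_true(try_eval_ge(size, 16))  # undecided: False
```

A caller would use the result only to enable an optional optimization, so returning `False` for an undecided condition costs performance at worst, never correctness.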
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117359
Approved by: https://github.com/lezcano
* Enable PERF402. Makes code more efficient and succinct by replacing manual element-by-element list copies with a list constructor or an `extend` call. All test cases have `noqa` added, since performance is less of a concern in that folder.
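A small sketch of the pattern PERF402 flags and the two replacements mentioned above; the function names are hypothetical and exist only for this illustration.

```python
def copy_with_loop(items):
    # The pattern PERF402 flags: an element-by-element copy.
    result = []
    for item in items:
        result.append(item)
    return result


def copy_with_constructor(items):
    # Equivalent copy via the list constructor.
    return list(items)


def extend_existing(target, items):
    # When appending to an existing list, a single extend call suffices.
    target.extend(items)
    return target
```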
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115505
Approved by: https://github.com/malfet