Commit Graph

15 Commits

Author SHA1 Message Date
Benjamin Glass
43c9b4e0e6 Fix unintentional deduplication of returned tensors (#134726)
When CSE was used, returned tensors that had gone through identical
processing steps but were distinct from a data perspective were pruned
out of the graph.  This commit protects tensors which are directly
output from being pruned, and adds a test for this behavior.

Closes #88813 and #114344

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134726
Approved by: https://github.com/amjames, https://github.com/zou3519, https://github.com/bdhirsh
2024-09-04 23:42:56 +00:00
rzou
092349dcdd Never CSE aten.empty in the partitioner (#134703)
aten.empty is almost always fusible into its consumer, so we never CSE
it. This fixes a bug that looks like the following:

```py
@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())

@torch.compile
def f(x):
    out0 = torch.empty_like(x)
    out1 = torch.empty_like(x)
    sin_cos(x, out0, out1)
    return x.clone(), out0, out1

x = torch.randn(3, requires_grad=True)
f(x)
```

- cse would de-duplicate the empty nodes
- reinplacing would add an additional clone (because it can't write to
  both tensors at the same time)
- the clone lowers into a new buffer + a copy_ kernel
- the copy_ kernel is unnecessary because "empty" is special - all reinplacing needed was an additional
  buffer, it doesn't matter what the values are.

We could attempt to fix this on the reinplacing side but this seemed
better as a partitioner heuristic and the reinplacing fix is a bit more
tricky (we'd need to identify that the op never reads from the empty
node).

Test Plan:
- new test (the old number was 27, the new number is 21, so this PR
  helped).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134703
Approved by: https://github.com/yf225
ghstack dependencies: #134466, #134490, #134491
2024-08-29 13:51:19 +00:00
Brian Hirsh
4db368a475 make functorch CSE respect mutations as barriers (like fsdp.set_) (#132243)
Fixes https://github.com/pytorch/pytorch/issues/132200

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132243
Approved by: https://github.com/albanD, https://github.com/zou3519, https://github.com/yf225
2024-08-05 21:28:55 +00:00
Xuehai Pan
e7eeee473c [BE][Easy][14/19] enforce style for empty lines in import segments in torch/_[a-c]*/ and torch/_[e-h]*/ and torch/_[j-z]*/ (#129765)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129765
Approved by: https://github.com/ezyang
2024-07-31 10:42:50 +00:00
chilli
f9a7033194 Refactor partitioner and clean it up (#126318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126318
Approved by: https://github.com/anijain2305
2024-05-17 06:15:00 +00:00
Xuehai Pan
73f0ecc1ac [BE] UFMT directory torch/_functorch (#123723)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123723
Approved by: https://github.com/Skylion007
2024-04-12 08:04:51 +00:00
Oguz Ulgen
5aab2b9acf Use graph.find_nodes in functorch (#122258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122258
Approved by: https://github.com/jansel
ghstack dependencies: #121565, #122255, #122256, #122257
2024-04-07 18:51:22 +00:00
vfdev-5
ed20e9118b Fixed hash issue in fx_graph_cse (#119567)
Description:
- Fixed issue with hash collision for `hash((primals_2, 1.0)) == hash((primals_2, 1))`

Repro code:
```python
import torch
from torch._functorch.compile_utils import fx_graph_cse

def func(inpt, osize):
    size = inpt.shape[-1]
    s1 = size - 1
    s2 = size - 1.0
    scale = s2 / (osize - 1.0)
    inpt = torch.clamp(inpt, 0, s1)
    return scale * inpt

gms = []
def toy_backend(gm, _):
    gms.append(gm)
    return gm.forward

torch._dynamo.reset()
fn = torch.compile(backend=toy_backend, dynamic=True)(func)
t = torch.rand(3, 100)
out = fn(t, 50)
gm = gms[0]

print(gm.graph)
new_fx_g = fx_graph_cse(gm.graph)
print(str(new_fx_g))
```
Original graph
```
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_inpt_ : torch.Tensor [num_users=2] = placeholder[target=L_inpt_]
    %l_osize_ : torch.SymInt [num_users=1] = placeholder[target=L_osize_]
    %size : [num_users=1] = call_method[target=size](args = (%l_inpt_,), kwargs = {})
    %getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%size, 1), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1.0), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=operator.sub](args = (%l_osize_, 1.0), kwargs = {})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%sub_1, %sub_2), kwargs = {})
    %inpt : [num_users=1] = call_function[target=torch.clamp](args = (%l_inpt_, 0, %sub), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%truediv, %inpt), kwargs = {})
    return (mul,)
```
New wrong graph where `sub_2` is replaced incorrectly with `sub`:
```
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_inpt_ : torch.Tensor [num_users=2] = placeholder[target=L_inpt_]
    %l_osize_ : torch.SymInt [num_users=1] = placeholder[target=L_osize_]
    %size : [num_users=1] = call_method[target=size](args = (%l_inpt_,), kwargs = {})
    %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%size, 1), kwargs = {})
    %sub : [num_users=2] = call_function[target=operator.sub](args = (%getitem_1, 1), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=operator.sub](args = (%l_osize_, 1.0), kwargs = {})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%sub, %sub_2), kwargs = {})
    %inpt : [num_users=1] = call_function[target=torch.clamp](args = (%l_inpt_, 0, %sub), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%truediv, %inpt), kwargs = {})
    return (mul,)
```
With this PR the new graph is the following:
```
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %s1 : torch.SymInt [num_users=0] = placeholder[target=s1]
    %l_inpt_ : torch.Tensor [num_users=2] = placeholder[target=L_inpt_]
    %l_osize_ : torch.SymInt [num_users=1] = placeholder[target=L_osize_]
    %size : [num_users=1] = call_method[target=size](args = (%l_inpt_,), kwargs = {})
    %getitem_1 : [num_users=2] = call_function[target=operator.getitem](args = (%size, 1), kwargs = {})
    %sub : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=operator.sub](args = (%getitem_1, 1.0), kwargs = {})
    %sub_2 : [num_users=1] = call_function[target=operator.sub](args = (%l_osize_, 1.0), kwargs = {})
    %truediv : [num_users=1] = call_function[target=operator.truediv](args = (%sub_1, %sub_2), kwargs = {})
    %inpt : [num_users=1] = call_function[target=torch.clamp](args = (%l_inpt_, 0, %sub), kwargs = {})
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%truediv, %inpt), kwargs = {})
    return (mul,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119567
Approved by: https://github.com/eellison
2024-02-12 18:52:11 +00:00
Edward Z. Yang
9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
Kazuaki Ishizaki
6d7744ca46 Fix typo under torch/_functorch directory (#111067)
This PR fixes typo the the of comments and exception messages in files under `torch/_functorch` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111067
Approved by: https://github.com/Skylion007
2023-10-11 23:09:36 +00:00
XiaobingSuper
afd621ddde inductor: fix CSE issue when have symbolic shape input at the freezing path (#105651)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105651
Approved by: https://github.com/jgong5, https://github.com/eellison
2023-07-26 08:07:31 +00:00
Richard Zou
4068c5467d [Reland] Move functorch/_src to torch/_functorch (#88756) (#90091)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091
Approved by: https://github.com/anijain2305, https://github.com/ezyang
2022-12-03 14:17:15 +00:00
PyTorch MergeBot
218d9c6e09 Revert "Move functorch/_src to torch/_functorch (#88756)"
This reverts commit 52bc5c1cfe.

Reverted https://github.com/pytorch/pytorch/pull/88756 on behalf of https://github.com/clee2000 due to broke imports in tests 52bc5c1cfe https://github.com/pytorch/pytorch/actions/runs/3574742513/jobs/6010814968 probably a landrace
2022-11-29 17:17:11 +00:00
Richard Zou
52bc5c1cfe Move functorch/_src to torch/_functorch (#88756)
This will be the last disruptive functorch internals change.

Why are we moving these files?
- As a part of rationalizing functorch we are moving the code in
functorch/_src to torch/_functorch
- This is so that we can offer the functorch APIs as native PyTorch APIs
(coming soon) and resolve some internal build issues.

Why are we moving all of these files at once?
- It's better to break developers all at once rather than many times

Test Plan:
- wait for tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88756
Approved by: https://github.com/ezyang
2022-11-29 13:55:42 +00:00