Commit Graph

227 Commits

Author SHA1 Message Date
Wenzhe Xue
8dfac7b887 Update fx.pass.graph_drawer usage doc to draw fx graph (#95534)
Previous usage gave this error:
```
f.write(g.get_dot_graph().create_svg())
TypeError: write() argument must be str, not bytes
```

pydot has functions for saving to different file types, e.g. `save_svg()`. I updated the usage doc so the example code works.
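
For reference, a minimal sketch of usage that avoids the error (requires pydot and graphviz; the traced function is illustrative, and writing the bytes in binary mode is one workaround, separate from the doc's switch to pydot's dedicated save helper):

```python
import torch
import torch.fx as fx
from torch.fx.passes.graph_drawer import FxGraphDrawer

def f(x):
    return torch.relu(x) + 1

g = FxGraphDrawer(fx.symbolic_trace(f), "example")
svg_bytes = g.get_dot_graph().create_svg()   # create_svg() returns bytes, not str
with open("example.svg", "wb") as out:       # binary mode avoids the TypeError above
    out.write(svg_bytes)
```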

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95534
Approved by: https://github.com/ezyang
2023-02-27 19:27:29 +00:00
Renfei Chen
c44a733018 Fix split_module bug (#95493)
Summary: Title; the mapping currently has lots of unused keys because the `or` condition always returns True, but it does not affect correctness.

Test Plan: N/A

Differential Revision: D43579510

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95493
Approved by: https://github.com/Skylion007
2023-02-27 19:11:49 +00:00
Renfei Chen
0d2e91573e Reorder the Fx execution order to in-time get_attr rather than putting all get_attr ahead (#95014)
Summary:
Basically, today the order is:
[getattr, ..., getattr, call partition1, call partition2]
This change makes getattr just-in-time:
[getattr, call partition1, getattr, call partition2, ...]

Test Plan:
CMF and MAI test result:
https://fb.quip.com/K5J9A7G246Ox

Differential Revision: D43376080

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95014
Approved by: https://github.com/angelayi
2023-02-21 20:05:30 +00:00
Aaron Gokaslan
67d9790985 [BE] Apply almost all remaining flake8-comprehension checks (#94676)
Applies the remaining flake8-comprehension fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into just the `set` call.
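
For illustration, the kinds of rewrites this covers (hedged examples, not lines taken from the diff):

```python
items = [("a", 1), ("b", 2)]
b = [1, 2, 2, 3]

# Unnecessary generator expressions become comprehensions:
squares = {x * x for x in range(10)}   # instead of set(x * x for x in range(10))
pairs = {k: v for k, v in items}       # instead of dict((k, v) for k, v in items)

# Useless generators that only re-wrap an iterable are dropped entirely:
unique = set(b)                        # instead of set(a for a in b)
```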

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
2023-02-12 01:01:25 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592
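
The core rewrite looks roughly like this (an illustrative sketch, not a hunk from the diff):

```diff
 class MyModule(nn.Module):
     def __init__(self):
-        super(MyModule, self).__init__()
+        super().__init__()
         self.linear = nn.Linear(4, 4)
```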

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
Wei-Sheng Chin
c5c7687b74 Allow FakeTensorProp to run on graphs traced with some None inputs (#94569)
Without this tiny change in `torch/_subclasses/fake_tensor.py`, the added test may fail with
```
TypeError: cannot create weak reference to 'NoneType' object
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94569
Approved by: https://github.com/ezyang
2023-02-10 20:38:22 +00:00
Angela Yi
d990ddadd5 [fx] Fix matching args (#94375)
To match nodes within the graph, the matcher currently flattens the arguments and compares each argument against each other. However, if it believes that a list input contains all literals, it will not flatten the list and will instead compare the list directly against each other. It determines if a list is a literal by checking if the first element is a node. However this doesn't work in some cases (like the test cases I added).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94375
Approved by: https://github.com/SherlockNoMad
2023-02-10 17:37:57 +00:00
Xuehai Pan
a229b4526f [BE] Prefer dash over underscore in command-line options (#94505)
Preferring dash over underscore in command-line options. Add `--command-arg-name` to the argument parser. The old arguments with underscores `--command_arg_name` are kept for backward compatibility.

Both dashes and underscores are used in the PyTorch codebase. Some argument parsers only have dashes or only have underscores in arguments. For example, the `torchrun` utility for distributed training only accepts underscore arguments (e.g., `--master_port`). The dashes are more common in other command-line tools. And it looks to be the default choice in the Python standard library:

`argparse.BooleanOptionalAction`: 4a9dff0e5a/Lib/argparse.py (L893-L895)

```python
class BooleanOptionalAction(Action):
    def __init__(...):
            if option_string.startswith('--'):
                option_string = '--no-' + option_string[2:]
                _option_strings.append(option_string)
```

It adds `--no-argname`, not `--no_argname`. Also, typing `_` requires pressing the Shift key, while `-` does not.
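
A minimal sketch of keeping both spellings working with `argparse` (the option name here is just an example):

```python
import argparse

parser = argparse.ArgumentParser()
# Register the dashed spelling as primary and keep the underscored spelling
# as an alias for backward compatibility; both map to the same dest.
parser.add_argument("--master-port", "--master_port", type=int, default=29500)

args = parser.parse_args(["--master_port", "1234"])
print(args.master_port)  # 1234
```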

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94505
Approved by: https://github.com/ezyang, https://github.com/seemethere
2023-02-09 20:16:49 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.
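
The two rewrites look roughly like this (an illustrative sketch; the names are made up):

```diff
-from __future__ import unicode_literals
-
-class GraphModuleWrapper(object):
+class GraphModuleWrapper:
     pass
```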

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Eli Uriegas
567e6152da Revert "[inductor] fix crash issue when input is a view tensor (#90150)" (#94329)
Had to provide a merge conflict resolution due to conflicts with https://github.com/pytorch/pytorch/pull/94118

This was causing issues with internal tests that look similar to:
```
in clone_preserve_strides
    x.size(), x.stride(), x.storage_offset()
AttributeError: 'KeyedJaggedTensor' object has no attribute 'size'
```

See https://fburl.com/testinfra/nc0du2sp for more information

This reverts commit #90150

@jansel can you help @blzheng with re-landing this as a co-development diff?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94329
Approved by: https://github.com/jansel
2023-02-07 20:45:58 +00:00
Elias Ellison
e4f11e01bd [Fake Tensor] Allow fake meta by default, delete unused ctor args (#93993)
Two small changes that I'm bundling together because one of them needs to touch fbcode and I'm not sure how to do stacked diffs + internal changes + land before release cut.

Remove allow_meta from the ctor, and allow it by default: we should be able to trace through meta with fake tensors, so in some sense it's a bit weird to expose an option to the user to disallow this. However, it's still useful debug-wise to error from time to time, so I've added an option to the config that restores the previous behavior.

Remove `throw_on_data_dependent_ops=True`: this was intended as temporary behavior while we smoothed over turning on the erroring. There are no uses anywhere of `throw_on_data_dependent_ops=False` that I could find.

These are technically backward-incompatible, but fake tensor is new since the last release / in a private namespace, and I don't want to release it with baggage that would be hard to remove later.

Fix for https://github.com/pytorch/pytorch/issues/92877.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93993
Approved by: https://github.com/bdhirsh, https://github.com/ezyang
2023-02-03 09:23:38 +00:00
blzheng
a71395dd88 [inductor] fix crash issue when input is a view tensor (#90150)
Fix the crash failure mentioned in https://github.com/pytorch/pytorch/issues/93460

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90150
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-02-03 04:54:14 +00:00
PyTorch MergeBot
5d259425fc Revert "[inductor] fix crash issue when input is a view tensor (#90150)"
This reverts commit b11ec270ba.

Reverted https://github.com/pytorch/pytorch/pull/90150 on behalf of https://github.com/clee2000 due to failing test_inplace_unsqueeze3 (__main__.CPUReproTests) https://github.com/pytorch/pytorch/actions/runs/4074618739/jobs/7020199369 b11ec270ba, marking as landrace cuz all jobs are green on pr
2023-02-02 17:06:34 +00:00
blzheng
b11ec270ba [inductor] fix crash issue when input is a view tensor (#90150)
Fix the crash failure mentioned in https://github.com/pytorch/pytorch/issues/93460

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90150
Approved by: https://github.com/jgong5, https://github.com/jansel
2023-02-02 12:49:26 +00:00
PyTorch MergeBot
db466ae057 Revert "[Modes] Add assert that the mode isn't already on the stack (#90770)"
This reverts commit 702838637d.

Reverted https://github.com/pytorch/pytorch/pull/90770 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-01-12 16:44:29 +00:00
samdow
702838637d [Modes] Add assert that the mode isn't already on the stack (#90770)
Redo of #89726 on a clean PR, thanks @voznesenskym for the first draft!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90770
Approved by: https://github.com/ezyang
2023-01-11 15:19:43 +00:00
Chen Lai
fd213c3231 Match get_attr when compare node (#91657)
The pattern can't be matched if one attribute is `_param_constant1` and the other is `_param_constant0`

Large graph:
```
        # call_function  addmm_default      aten.addmm.default  (_param_constant1, ph_0, _tensor_constant0)  {}
```

Pattern graph
```
        # call_function  addmm_default      aten.addmm.default  (_param_constant0, ph_0, _tensor_constant0)  {}
```

Differential Revision: [D42316574](https://our.internmc.facebook.com/intern/diff/D42316574/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91657
Approved by: https://github.com/SherlockNoMad
2023-01-09 08:10:55 +00:00
jjsjann123
192a11d49c refactor the dfs cyclic search from recursion to iterative approach (#91042)
Follow up on PR #86511

Python's default recursion depth limit of 1000 makes it impractical to run the cyclic check on larger graphs. This refactor avoids that issue.
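
For reference, a generic sketch of the idea (not the partitioner's actual code): replace the recursive DFS with an explicit stack so the check no longer hits the default recursion limit.

```python
def has_cycle(graph):
    """Iterative DFS cycle check; `graph` maps every node to an iterable of successors."""
    WHITE, GRAY, BLACK = 0, 1, 2               # unvisited / on current path / finished
    color = {node: WHITE for node in graph}
    for start in graph:
        if color[start] != WHITE:
            continue
        color[start] = GRAY
        stack = [(start, iter(graph[start]))]  # explicit stack instead of recursion
        while stack:
            node, successors = stack[-1]
            advanced = False
            for nxt in successors:
                if color[nxt] == GRAY:         # back edge -> cycle
                    return True
                if color[nxt] == WHITE:
                    color[nxt] = GRAY
                    stack.append((nxt, iter(graph[nxt])))
                    advanced = True
                    break
            if not advanced:                   # all successors of this node handled
                color[node] = BLACK
                stack.pop()
    return False

print(has_cycle({"a": ["b"], "b": ["c"], "c": ["a"]}))    # True
print(has_cycle({"a": ["b", "c"], "b": ["c"], "c": []}))  # False
```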
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91042
Approved by: https://github.com/kit1980
2022-12-20 23:15:30 +00:00
Mergen Nachin
5e3bc1975b Add any_chain() in upstream (#90949)
Summary: I need an "any" chain. The current chain is a logical AND.
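
A minimal sketch of the distinction (names and signatures assumed here; not the actual torch.fx.passes.operator_support API):

```python
def chain(*checks):
    # logical AND: a node is supported only if every check accepts it
    return lambda node: all(check(node) for check in checks)

def any_chain(*checks):
    # logical OR: a node is supported if any single check accepts it
    return lambda node: any(check(node) for check in checks)

def is_small(n): return n < 10
def is_even(n): return n % 2 == 0
print(chain(is_small, is_even)(12), any_chain(is_small, is_even)(12))  # False True
```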

Test Plan: arc lint, follow-up diffs use it.

Differential Revision: D42078837

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90949
Approved by: https://github.com/angelayi
2022-12-16 04:09:10 +00:00
XiaobingSuper
ffa89033c5 TorchDynamo: always convert tensor to fake tensor at fake_mode path for ShapeProp (#90685)
This PR will fix https://github.com/pytorch/torchdynamo/issues/1978. For HF models, ShapeProp always reported an error. The root cause is that we use fake tensor mode to do the ShapeProp, but **torch.ones** always produces a non-fake tensor, which introduces an operation mixing non-fake tensors with fake tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90685
Approved by: https://github.com/ezyang, https://github.com/jansel
2022-12-13 06:59:43 +00:00
Angela Yi
02eb0bdbc1 [fx] Added better tests to pass infra (#90432)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90432
Approved by: https://github.com/SherlockNoMad
2022-12-09 21:43:18 +00:00
Sergii Dymchenko
f51f6aa387 Fix non-existing parameters in docstrings (#90505)
Continuation after https://github.com/pytorch/pytorch/pull/90163.

Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in presence of *args/**kwargs or decorators):

_Edit:_
I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with corrected script.

``` python
import ast
import os
import docstring_parser

for root, dirs, files in os.walk('.'):
    for name in files:
        if root.startswith("./.git/") or root.startswith("./third_party/"):
            continue
        if name.endswith(".py"):
            full_name = os.path.join(root, name)
            with open(full_name, "r") as source:
                tree = ast.parse(source.read())
                for node in ast.walk(tree):
                    if isinstance(node, ast.FunctionDef):
                        all_node_args = node.args.args
                        if node.args.vararg is not None:
                            all_node_args.append(node.args.vararg)
                        if node.args.kwarg is not None:
                            all_node_args.append(node.args.kwarg)
                        if node.args.posonlyargs is not None:
                            all_node_args.extend(node.args.posonlyargs)
                        if node.args.kwonlyargs is not None:
                            all_node_args.extend(node.args.kwonlyargs)
                        args = [a.arg for a in all_node_args]
                        docstring = docstring_parser.parse(ast.get_docstring(node))
                        doc_args = [a.arg_name for a in docstring.params]
                        clean_doc_args = []
                        for a in doc_args:
                            clean_a = ""
                            for c in a.split()[0]:
                                if c.isalnum() or c == '_':
                                    clean_a += c
                            if clean_a:
                                clean_doc_args.append(clean_a)
                        doc_args = clean_doc_args
                        for a in doc_args:
                            if a not in args:
                                print(full_name, node.lineno, args, doc_args)
                            break

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
2022-12-09 21:43:09 +00:00
Angela Yi
a076bdb357 [fx] Copy codegen in legalize_graph (#90023)
Test Plan: CI

Differential Revision: D41666330

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90023
Approved by: https://github.com/SherlockNoMad
2022-12-07 21:09:38 +00:00
Tran Le
b769005924 [fx][passes] Implement annotate getitem node FX passes (#90237)
Summary: One common cause of jit unscriptability issues is the loss of node type annotations on local names after one or several FX transforms. One way to improve type coverage is to eagerly annotate `getitem` nodes with the type derived from their parent sequence node. This diff introduces an FX pass to do that.

Test Plan:
```
buck2 test //caffe2/test:fx_experimental
```

Reviewed By: xush6528

Differential Revision: D41749744

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90237
Approved by: https://github.com/xush6528
2022-12-06 23:18:55 +00:00
Ram Rachum
77f9b2e8bf Fix exception causes in fx, nn and onnx packages (#90134)
This is a continuation of #90118

@kit1980
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90134
Approved by: https://github.com/kit1980
2022-12-06 04:34:58 +00:00
Michael Voznesensky
41c3b41b92 Use dynamo fake tensor mode in aot_autograd, move aot_autograd compilation to lowering time [Merger of 89672 and 89773] (#90039)
After all of the preparatory commits, this is a subset of the
changes in https://github.com/pytorch/pytorch/pull/89392 that actually
change us to propagating fake tensors to backends.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

This is the merger of Ed's PR #89672, which is a rewrite of an older PR of mine (#89392), with CI Fixes on top of it (#89773)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90039
Approved by: https://github.com/ezyang
2022-12-05 01:56:50 +00:00
PyTorch MergeBot
4648baa911 Revert "Use dynamo fake tensor mode in aot_autograd, move aot_autograd compilation to lowering time [Merger of 89672 and 89773] (#90039)"
This reverts commit ef0c7ec958.

Reverted https://github.com/pytorch/pytorch/pull/90039 on behalf of https://github.com/clee2000 due to broke xla tests ef0c7ec958 https://github.com/pytorch/pytorch/actions/runs/3606308473/jobs/6077646142
2022-12-04 21:57:30 +00:00
Michael Voznesensky
ef0c7ec958 Use dynamo fake tensor mode in aot_autograd, move aot_autograd compilation to lowering time [Merger of 89672 and 89773] (#90039)
After all of the preparatory commits, this is a subset of the
changes in https://github.com/pytorch/pytorch/pull/89392 that actually
change us to propagating fake tensors to backends.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

This is the merger of Ed's PR #89672, which is a rewrite of an older PR of mine (#89392), with CI Fixes on top of it (#89773)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90039
Approved by: https://github.com/ezyang
2022-12-03 01:19:55 +00:00
Tran Le
0e1fcc8aa8 [FX] Add type annotation to getitem node before split_module (#88510)
Summary: Some nodes lost their type annotation during `split_module`, causing the submodels to be un-scriptable. This is because the compiler always infers Tensor type, which is wrong for non-Tensor values. We attempt to infer the type annotation for `getitem` nodes to improve scriptability.

Test Plan:
```
buck2 test //caffe2/test:fx_experimental
```

Differential Revision: D41037819

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88510
Approved by: https://github.com/xush6528
2022-11-18 23:19:14 +00:00
Kazuaki Ishizaki
1cd6ebe095 Fix typos in messages under torch (#89049)
This PR fixes typos in messages in `.py` files under the torch directory.
Only in `torch/onnx/symbolic_opset16.py`, it additionally fixes a typo in a comment to make the operator name correct.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89049
Approved by: https://github.com/lezcano
2022-11-17 04:18:14 +00:00
Riley Dulin
f73d9a79fe [torch][fx] Fix PassManager to not use a class variable mutable list (#89108)
Summary:
I found a confusing bug in the PassManager that only happens
when you instantiate one multiple times: it will use old passes and
constraints!

This occurs because the class-level declarations initialize it to an empty list,
but the problem is that class initializers only run once, and are creating class
variables. This means the same empty list was being reused every time, except
after the first time it isn't empty.

The empty list has to be created in `__init__` newly each time or else it'll be shared.
Note that this is the same type of bug as using an empty list as a default parameter, where
it'll reuse the same list pointer and not make it empty each time.

The better way to do this is with either:
* An immutable default parameter like an empty tuple, that you create a new list from: `self.passes = list(passes)`
* Use None and then create the empty list inside `__init__`

I chose the latter as it's less likely to cause a behavior change due to the changed default.

Note that for immutable values like `False` and `1` this doesn't apply as you can't mutate that
value for everyone.
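
A simplified illustration of the bug and the fix (not the actual PassManager code):

```python
class PassManagerBuggy:
    passes = []                      # class-level list, shared by every instance!

    def add_pass(self, p):
        self.passes.append(p)        # mutates the shared class attribute


class PassManagerFixed:
    def __init__(self, passes=None):
        # create a fresh list per instance instead of reusing a mutable default
        self.passes = list(passes) if passes is not None else []


a, b = PassManagerBuggy(), PassManagerBuggy()
a.add_pass("p1")
print(b.passes)                      # ['p1'] -- state leaked between instances
print(PassManagerFixed().passes)     # []
```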

Test Plan:
Added a test to ensure that the pass state is not saved.
Without my change, this test would fail as it would run all of the `2 * x` passes first,
then all of the `3 * x` passes.

Differential Revision: D41327056

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89108
Approved by: https://github.com/angelayi
2022-11-17 02:43:33 +00:00
Wei-Sheng Chin
86b7aa26f0 Fix FakeTensorProp on Module with Parameters or Buffers (#88700)
In `FakeTensorMode.__torch_dispatch__`, the output is now always computed by meta kernels in
```python
        try:
            with in_kernel_invocation_manager(self):
                r = func(*args, **kwargs)  # <----- "r" can be a real tensor.
        except NotImplementedError as not_implemented_error:
            # no meta kernel registered, fallback to kernel for the device
            if not self.allow_fallback_kernels:
                raise not_implemented_error
            return run_fallback_kernel(self, func, args, kwargs, not_implemented_error)

        return self.wrap_meta_outputs_with_default_device_logic(r, func, args, kwargs)
```
For example, I observed that a CPU tensor is generated when executing `aten.addmm` while running `FakeTensorProp`. Therefore, I'd like to allow `FakeTensorMode` to wrap real tensors as `FakeTensor`s during the computation. Does this PR look like a good direction to fix this problem? If yes, I can go ahead and add some tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88700
Approved by: https://github.com/eellison, https://github.com/ezyang
2022-11-11 03:49:29 +00:00
Sherlock Huang
1d82eba98b PatternMatcher supports matching list-typed args (#88656)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88656
Approved by: https://github.com/jerryzh168
2022-11-08 21:05:18 +00:00
Kurt Mohler
ee28b865ee Deprecate TypedStorage, its derived classes, and all of their public methods (#85303)
Part of #85302

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85303
Approved by: https://github.com/ezyang
2022-11-08 18:11:01 +00:00
jjsjann123
af09270e10 nvprims bookend non compute (#88457)
Cherry-picking: https://github.com/csarofeen/pytorch/pull/2099

1. enabling the bookend non-compute-ops pass on nvfuser
2. fixing the bookend op check on intermediate tensors used as partition inputs
3. python tests added for `getitem` special handling and bookend_non_compute removal
4. patching dfs by excluding dfs within a partition to avoid going over the recursion limit
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88457
Approved by: https://github.com/SherlockNoMad
2022-11-08 12:06:35 +00:00
Angela Yi
91a4039842 [exir][fx] PassManager error handling (#88520)
Summary:
* Added an error message for when the result is not a PassResult
* Modified the error handling to capture exceptions that happen in the check() function
* consolidated inplace_wrapper and pass_result_wrapper

Test Plan: CI

Differential Revision: D40950135

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88520
Approved by: https://github.com/SherlockNoMad
2022-11-07 18:42:41 +00:00
Mor Tzur
6575174dcb [fx2ait] fixes for AITSplitter (#87805)
Summary: propagate lower settings to AITSplitter settings.

Reviewed By: yinghai, qxy11

Differential Revision: D40568216

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87805
Approved by: https://github.com/yinghai
2022-11-04 20:18:08 +00:00
Brian Hirsh
b5a925ff2e propagate .meta info when replacing subgraphs in fx (#87255)
Fixes https://github.com/pytorch/torchdynamo/issues/1708

Our FX subgraph partitioner works by taking all of the original output nodes from a subgraph, and replacing it with a new `call_module` node in the graph.

If the original subgraph outputs had fake tensors and other metadata stored in their `.meta` attribute though, then this information was getting lost when we spliced in the subgraph.

Losing metadata on an FX graph also seems like an easy trap to fall into, so I'm wondering if there are any better guardrails we can add. I ended up fixing it in this PR by adding an optional kwarg to propagate meta info directly in `fx.Node.replace_all_uses_with`, just because propagating metadata seems like a pretty core thing.
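
A hedged sketch of the new kwarg in use (the kwarg name comes from this PR; the surrounding graph surgery is illustrative):

```python
import torch
import torch.fx as fx

def f(x):
    return torch.relu(x) + 1

gm = fx.symbolic_trace(f)
relu = next(n for n in gm.graph.nodes if n.target is torch.relu)
relu.meta["val"] = "example metadata"

# Swap relu for sigmoid and rewire its users; propagate_meta copies .meta over.
with gm.graph.inserting_after(relu):
    replacement = gm.graph.call_function(torch.sigmoid, args=relu.args)
relu.replace_all_uses_with(replacement, propagate_meta=True)
gm.graph.erase_node(relu)
gm.recompile()
print(replacement.meta)  # the 'val' entry carried over from the replaced node
```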

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87255
Approved by: https://github.com/wconstab, https://github.com/SherlockNoMad
2022-11-02 14:36:46 +00:00
Shiyan Deng
fb1586fbcb Make a copy of the submodule inputs (#87899)
Summary: There might be inplace ops in the model that would change the saved inputs. To avoid that, we save a deepcopy version.

Test Plan: CI

Differential Revision: D40771290

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87899
Approved by: https://github.com/houseroad
2022-11-01 22:42:04 +00:00
Ivan Yashchuk
0eea05b11e Remove "prims_nvfuser" backend for TorchDynamo (#88083)
Removing "prims_nvfuser" backend according to the discussion in https://github.com/pytorch/torchdynamo/pull/1281#discussion_r979468355.

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88083
Approved by: https://github.com/ezyang
2022-11-01 03:09:37 +00:00
Ivan Yashchuk
ae4fbac819 Enable nvprims.transpose fusions for nvFuser (#86967)
This PR allows transposes to be fused with other operations. If a fusion group is formed only from operations that just manipulate metadata in PyTorch (transpose, view, etc.) then this group is not sent to nvFuser.
On top of that, if we have converted to `nvprims` but then decided not to form a fusion group, we modify the graph to use the `prim.impl_aten` attribute instead of calling `prim(*args, **kwargs)`, which has higher overhead.

cc @kevinstephano @jjsjann123
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86967
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
2022-10-26 17:00:07 +00:00
Ivan Yashchuk
72f446b9bc Remove getitem special handling in the partitioner (#87073)
This special handling of getitem unnecessarily splits fusions at functions with tuple outputs.

Example script:
```py
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch.fx.experimental.proxy_tensor import make_fx

def func(x):
    xx = torch.ops.nvprims.add(x, 1)
    var, mean = torch.ops.nvprims.var_mean(x, correction=0)
    var_cos = torch.ops.nvprims.cos(var)
    mean_sin = torch.ops.nvprims.sin(mean)
    return torch.ops.nvprims.add(var_cos, mean_sin)

a = torch.randn(5, 3, 3, device="cuda")
gm = make_fx(func)(a)
gm.graph.print_tabular()

supported_ops = NvfuserPrimOperatorSupport()
partitioner = CapabilityBasedPartitioner(
    gm, supported_ops, allows_single_node_partition=False
)
partitions = partitioner.propose_partitions()
print(partitions)
partitioned_graph = partitioner.fuse_partitions(partitions)
partitioned_graph.graph.print_tabular()
```
Output on master:
```py
opcode         name       target                       args              kwargs
-------------  ---------  ---------------------------  ----------------  -----------------
placeholder    x_1        x_1                          ()                {}
call_function  add        nvprims.add.default          (x_1, 1)          {}
call_function  var_mean   nvprims.var_mean.main        (x_1, [0, 1, 2])  {'correction': 0}
call_function  getitem    <built-in function getitem>  (var_mean, 0)     {}
call_function  getitem_1  <built-in function getitem>  (var_mean, 1)     {}
call_function  cos        nvprims.cos.default          (getitem,)        {}
call_function  sin        nvprims.sin.default          (getitem_1,)      {}
call_function  add_1      nvprims.add.default          (cos, sin)        {}
output         output     output                       (add_1,)          {}
[{cos, sin, add_1}, {var_mean, add, getitem, getitem_1}]
opcode         name       target                       args                    kwargs
-------------  ---------  ---------------------------  ----------------------  --------
placeholder    x_1        x_1                          ()                      {}
call_module    fused_1    fused_1                      (x_1,)                  {}
call_function  getitem_2  <built-in function getitem>  (fused_1, 0)            {}
call_function  getitem_3  <built-in function getitem>  (fused_1, 1)            {}
call_module    fused_0    fused_0                      (getitem_2, getitem_3)  {}
output         output     output                       (fused_0,)              {}
```
Output with this PR:
```
[{var_mean, add_1, cos, sin, add, getitem_1, getitem}]
opcode       name     target    args        kwargs
-----------  -------  --------  ----------  --------
placeholder  x_1      x_1       ()          {}
call_module  fused_0  fused_0   (x_1,)      {}
output       output   output    (fused_0,)  {}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87073
Approved by: https://github.com/jjsjann123, https://github.com/SherlockNoMad
2022-10-26 14:18:46 +00:00
Soof Golan
874a94ce94 Fix tensor.stride() type hint (#84177)
`tensor.stride()` now hints at a tuple of variable length instead of a tuple with a constant length of 1.
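
Roughly, the annotation changes like this (an illustrative sketch of the stub, not the exact declaration):

```diff
-    def stride(self) -> Tuple[int]: ...        # fixed-length 1-tuple
+    def stride(self) -> Tuple[int, ...]: ...   # variable-length tuple
```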

Fixes #84176

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84177
Approved by: https://github.com/Chillee
2022-10-25 04:43:10 +00:00
Kazuaki Ishizaki
d80a5f9a96 Fix typo under torch directory (#87274)
This PR fixes typos in `.md` files under the torch directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87274
Approved by: https://github.com/albanD
2022-10-21 14:22:20 +00:00
Sherlock Huang
e271e823c7 Avoid calling logging.basicConfig (#86959)
Fixes https://github.com/pytorch/pytorch/issues/85952

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86959
Approved by: https://github.com/xwang233, https://github.com/davidberard98
2022-10-17 16:45:21 +00:00
jjsjann123
f903f1ab34 Patching getitem in partitioner (#86713)
1. rejecting the getitem operator in backends' fusion queries: getitem is merged in a special post-partition pass, so backends that take getitem shouldn't affect the logic
2. added tests for the failing cases

Fixes #86698

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86713
Approved by: https://github.com/SherlockNoMad
2022-10-12 07:50:46 +00:00
jjsjann123
2cb330ab15 Acyclic partition patch (#86511)
Fixes #86159 and #86108

Refactored graph partition to check for cyclic dependency on each partition merge, instead of relying on a pre-baked dependency map.

The previous implementation suffered from not updating dependencies on existing partitions. When a fusion happens, the updated dependency map needs to be propagated to all nodes in the graph so that each node in a partition shares an identical dependency set. As a result, the previous implementation failed to identify the cyclic dependency in issue #86159.

Updated implementation does a cyclic check on partitioned graph before attempting a merge of two partitions.

- [x] python repro added with cyclic dependency after partition `TestFXGraphPasses.forward12`
- [x] fix dependency map with updated implementation using cyclic check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86511
Approved by: https://github.com/SherlockNoMad
2022-10-10 23:48:52 +00:00
Sherlock Huang
2fec853c87 Fix SubgraphMatcher for case of no anchor found (#86421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86421
Approved by: https://github.com/jerryzh168
2022-10-07 02:05:42 +00:00
Richard Zou
a4ff07f197 Stop modifying the global logger on import functorch (#86147)
Fixes https://github.com/pytorch/pytorch/issues/85952

`logging.basicConfig` modifies the global logger, which affects other
programs. Importing a package should generally be side-effect free, so
this PR gets rid of that call.

Test Plan:
- tested locally
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86147
Approved by: https://github.com/ezyang
2022-10-04 02:33:54 +00:00
Sherlock Huang
5547c6aa4e Match kwargs in SubgrpahMatcher (#85617)
Pattern node and target node must have identical kwargs now...

Use envvar `LOGLEVEL=INFO` to turn on the logging message for easier debugging...

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #85617
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85617
Approved by: https://github.com/jerryzh168, https://github.com/davidberard98
2022-09-26 21:26:07 +00:00
Sherlock Huang
a8add2b92f Support matching Args for SubgraphMatcher (#85456)
Subgraph matcher now handles the matching of non-Node arguments.

Here are the 4 cases
- pn is a Node, gn is a Node: this goes through the regular _match_node() function
- pn is a Node, gn is not a Node: this is a match only if pn is a placeholder op
- pn is not a Node, gn is a Node: this is a no-match case
- pn is not a Node, gn is not a Node: this goes through the argument comparison.

With this change
```
def target(x):
    return foo(x, 3)

def pattern(x, y):
    return foo(x, y)
```

is a match

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85456
Approved by: https://github.com/jerryzh168
2022-09-24 20:06:48 +00:00
Renfei Chen
4befe45084 [FX] Add one option to maintain the FX graph execution order after splitting_module (#85188)
Summary:
{F770932209}

Given the original execution order and the node dependency relationships (note that the same dependencies can produce multiple valid execution orders, i.e., topological orders), after reunion we may find that the execution order of the new GraphModule differs from the original one, which is not what we want.
For example, assume NewLeaf_1 is EmbeddingLookup (calling EmbeddingLookup is awaitable: we keep executing the following nodes rather than waiting for the result until we have to use it), and NewLeaf_4 is the node where we HAVE to use the lookup result to interact with NewLeaf_3. NewLeaf_1 will launch a lookup kernel and an all2all communication stream to distribute the result to all ranks. In the meantime, we want to keep executing NewLeaf_2 and NewLeaf_3 to avoid meaningless waiting. However, given the new execution order, we have to wait for the lookup kernel and the all2all communication to finish because the next node, NewLeaf_4, needs the result; only then can we execute NewLeaf_2, etc. This fails to leverage the parallelism between computation and the communication stream and hurts QPS a lot.
So, while constructing the GraphModule, we have to change from the topological order back to the original order.

Test Plan:
Unit test

Not sure how to add tests in FX as there's no TARGETS, so I added in the TorchRec folder

Differential Revision: D39567314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85188
Approved by: https://github.com/SherlockNoMad
2022-09-23 23:21:54 +00:00
Sherlock Huang
34296e2f4c SubgraphMatcher remove invalid matches (#85444)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85444
Approved by: https://github.com/rkindi
2022-09-22 02:59:11 +00:00
Elias Ellison
8bd9fe3f49 Changes to prepare for fake tensors on in functorch by default (#84432)
Fixes some errors you run into in dynamo when turning on fake tensors. I'm waiting on flipping the switch because I need to also get some fixes into dynamo + do benchmarking.

I could manually turn off fake tensors in functorch in dynamo, and then turn it on here if requested, although the changes here are pretty minimal.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84432
Approved by: https://github.com/Chillee
2022-09-08 04:29:30 +00:00
Wei Wei
31ef8ddb8c add option to remove passes (#84425)
Summary:
Add a remove_pass method in pass_manager to give the user the option to remove any pass.

Reviewed By: wushirong

Differential Revision: D39080077

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84425
Approved by: https://github.com/yinghai
2022-09-07 17:21:27 +00:00
Qiming Lu
e71370064c Improvements to FX Minimizer (#83833)
Summary: This diff improves the FX Minimizer for better error reports, and fixes a few other issues.

Test Plan: CI

Differential Revision: D38900309

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83833
Approved by: https://github.com/yuhc, https://github.com/Chillee
2022-09-01 18:39:26 +00:00
Horace He
85931eaa6b Rename fake_result to val (#84331)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84331
Approved by: https://github.com/ezyang
2022-08-31 17:44:18 +00:00
Sungmin Cho
bf67589915 Escape curly brackets in FxGraphDrawer _typename (#83604)
Summary:
Encountered `Error: bad label format` from dot (i.e. graphviz) when benchmarking models that have dict-like structure.

The root cause was that curly brackets were not properly escaped, like this example P522499127 (unescaped curly brackets in target= string)

This diff inserts the fix in FxGraphDrawer, since much of this graph-generation code relies on that class.
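
A minimal illustration of the escaping (the helper name is hypothetical; the real fix lives inside FxGraphDrawer's label generation):

```python
def escape_dot_label(label: str) -> str:
    # Graphviz record-shaped labels treat { and } as structure, so escape them.
    return label.replace("{", r"\{").replace("}", r"\}")

print(escape_dot_label("Dict[str, Tensor] -> {out}"))  # Dict[str, Tensor] -> \{out\}
```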

(Modified summary before exporting to GitHub PR)

Test Plan:
```
CUDA_VISIBLE_DEVICES=7 buck run mode/opt -c python.package_style=inplace //hpc/new/models/feed/benchmark:feed_lower_benchmark -- --model-name={INSERT IFR QE MODEL NAME HERE} --batch-iter 100 --batch-size 768 --num-gpu 1 --lower-presets {INSERT ITS PRESET}
```

Will not encounter dot errors after this diff.

(Modified test plan before exporting to GitHub PR)

Reviewed By: yinghai

Differential Revision: D38758827

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83604
Approved by: https://github.com/yinghai, https://github.com/jianyuh
2022-08-31 15:15:21 +00:00
Isaac Hoffman
20018aa766 modify split_by_tags to retain output order (#84136)
Summary: Currently `split_by_tags` determines submodule output order by iterating over `used_in_main`. Since this is a `Set`, insertion order is not retained so we run into problems with submodule output order being "randomized" & inconsistent between splits. By using `Dict[Node, None]` we can implement `used_in_main` as an ordered set so that output order is consistent when splitting the same model.
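
A hedged illustration of the ordered-set trick (strings stand in for FX nodes here):

```python
# dict preserves insertion order (Python 3.7+), so Dict[Node, None] behaves like
# an ordered set with deterministic iteration order, unlike the built-in set.
used_in_main = {}
for name in ("out_b", "out_a", "out_c"):
    used_in_main.setdefault(name, None)

print(list(used_in_main))                  # ['out_b', 'out_a', 'out_c'] -- stable
print(list({"out_b", "out_a", "out_c"}))   # arbitrary, hash-dependent order
```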

Test Plan: CI

Differential Revision: D39039268

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84136
Approved by: https://github.com/houseroad
2022-08-30 20:36:33 +00:00
Zhengxu Chen
a402e100be [fx] Make wrapped_fn also work for non-mutating passes. (#84232)
Summary: Before the change, wrapped_fn should only take mutating passes, but we don't actually have any way to detect whether a pass is mutating before running it. To make this an abstraction without involving any precondition depending on PassManager run, we could just relax the precondition to take any kind of passes, and conditionally return the original pass based on the pass result.

Test Plan: eyes

Reviewed By: qihqi, angelayi

Differential Revision: D39086343

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84232
Approved by: https://github.com/angelayi
2022-08-30 01:16:58 +00:00
Angela Yi
352da6de6b [fx][pass] Fix type of exception (#84094)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84094
Approved by: https://github.com/SherlockNoMad
2022-08-29 16:55:59 +00:00
PyTorch MergeBot
1945d28f58 Revert "[fx][pass] Fix type of exception (#84094)"
This reverts commit eb2fa2e042.

Reverted https://github.com/pytorch/pytorch/pull/84094 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-08-29 16:41:09 +00:00
Angela Yi
eb2fa2e042 [fx][pass] Fix type of exception (#84094)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84094
Approved by: https://github.com/SherlockNoMad
2022-08-26 22:34:14 +00:00
Nan Xiao
c47e0450f8 [fbia] Keep Track of full qualified name before and after remote sharding (#83889)
Summary: track qualname changes in embedding sharding & FX split, and compose the target qualname at the end of the FBIA transform stage, so we can use the qualname mapping in the XL materialize stage.

Test Plan:
CI/CD

with DISABLE_XLEBB_MATERIALIZATION = True
https://fburl.com/fblearner/a8yljbux

with DISABLE_XLEBB_MATERIALIZATION = False
https://fburl.com/fblearner/2nvi0dam

Reviewed By: lliu315gt

Differential Revision: D38772525

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83889
Approved by: https://github.com/houseroad
2022-08-24 01:15:25 +00:00
Shirong Wu
fc470cf980 Back out "Support regex-style matching for Any and Oneof (#82853)" (#83922)
Reviewed By: hl475

Differential Revision: D38945806

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83922
Approved by: https://github.com/hl475
2022-08-24 00:17:46 +00:00
Angela Yi
89072177e1 [fx][pass infra] Adding error catching (#83933)
Example:

```
======================================================================
ERROR: test_pass_manager_error (fx.test_pass_infra.TestPassManager)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/angelayi/Projects/pytorch/torch/fx/passes/infra/pass_manager.py", line 285, in __call__
    res = fn(module)
  File "/Users/angelayi/Projects/pytorch/test/fx/test_pass_infra.py", line 164, in pass_fail
    raise RuntimeError("bad")
RuntimeError: bad

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/angelayi/Projects/pytorch/test/fx/test_pass_infra.py", line 170, in test_pass_manager_error
    pm(traced_m)
  File "/Users/angelayi/Projects/pytorch/torch/fx/passes/infra/pass_manager.py", line 289, in __call__
    raise RuntimeError(msg) from e
RuntimeError: An error occured when running the 'pass_fail' pass after the following passes: ['replace_add_with_mul_pass', 'replace_mul_with_div_pass']
```

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83933
Approved by: https://github.com/SherlockNoMad
2022-08-23 23:56:50 +00:00
Brian Hirsh
8db04c1113 reinplace pass: special handling for view_scatter ops (#83846)
There is already special handling in the reinplacing pass for removing `{view}_scatter` ops, but there is another case that needs special handling. In this code:
```
         def f():
             a = torch.zeros(4, 4, 4)
             a[:, 2:] = torch.ones(4, 2, 4)
             return a
```

Tracing normally with `make_fx()` gives you:
```

def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807);  slice_tensor = None
    copy__default = torch.ops.aten.copy_.default(slice_tensor_1, ones);  slice_tensor_1 = ones = None
    return zeros
```
Functionalizing it gives you:

```
def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807);  slice_tensor = None
    slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_scatter_default = torch.ops.aten.slice_scatter.default(slice_tensor_2, ones, 1, 2, 9223372036854775807);  slice_tensor_2 = ones = None
    slice_scatter_default_1 = torch.ops.aten.slice_scatter.default(zeros, slice_scatter_default, 0, 0, 9223372036854775807);  zeros = slice_scatter_default = None
    return slice_scatter_default_1
```

Notice that there are not any functional ops to directly re-inplace! What actually happened is that functionalization turned the `copy_()` into a `copy()`, but the out-of-place `copy()` operator gets optimized away because it's a no-op (when the input and output metadata are the same, `out = copy(a, b)` just returns `b`).

What we actually want is to replace this line:
```
slice_scatter_default = torch.ops.aten.slice_scatter.default(slice_tensor_2, ones, 1, 2, ...);
```
with this:
```
new_slice = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, ...);
_ = torch.ops.aten.copy_.default(new_slice, ones)
```

In the above, we're taking a fresh slice of the "base" tensor, and performing a `copy_()` on the slice, adding back what functionalization removed.

We actually need to create a fresh "slice" node, because we're not guaranteed that one already exists in the graph (technically there should be one, but it might have been DCE'd by the time we hit re-inplacing)

I also updated the docs for re-inplacing to more closely match the order of the logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83846
Approved by: https://github.com/ezyang
2022-08-23 17:13:58 +00:00
Brian Hirsh
75ec7b7547 reinplace pass: bugfix for output node replacement (#83845)
Cleaned up some of the arg replacement logic to use tree_map, so it handles FX nodes that have nested containers.

See the added test: when you write a function that returns a list, the `output` node in the FX graph shows up as having `node.args = tuple(immutable_list(...))`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83845
Approved by: https://github.com/ezyang
2022-08-23 17:13:58 +00:00
Alex Beloi
3c6c39e66e [fx] refactor fba_passes into FBAPassManagerBuilder (#83268)
Summary:
This diff integrate FBAPassManagerBuilder as the primary orchestrator of FBA-FX passes

Reviewed By: jfix71, dborkovic

Differential Revision: D38186354

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83268
Approved by: https://github.com/dborkovic
2022-08-22 06:54:18 +00:00
Brian Hirsh
e9e7363854 reinplacing pass fixes for torchbench + huggingface (#83626)
I'm testing out turning on re-inplacing + functionalization by default with the AOTAutograd + eager backend on torchbench + huggingface models. This PR contains a few bug fixes from turning re-inplacing on:

(1) Handle more gracefully when FakeTensorMode is already turned on when you call reinplace

(2) More robust detection for when an inplace variant of an op exists (the dumb bug was that `pow.Scalar` doesn't have an inplace variant, even though there are several overloads of `pow_`; none of them are eligible though).

(3) Avoid re-inplacing when it would require resizing the input buffer. This isn't allowed, because inplace ops aren't allowed to resize their inputs.

For the last one, I gave the two main examples in more detail in the comments. Important cases are:
```
# This should not be re-inplaced at all; the op broadcasts, so this would require resizing the self tensor
torch.add(tensor[1, 4], tensor[4, 4])

# This should not be re-inplaced, because the inplace and out-of-place variants of the op return different dtypes
torch.ge(a, b)
# However, this means that today when functionalization functionalizes a `torch.ge_(a, b)` call, reinplacing won't properly de-functionalize it. I mentioned that optimization is worth adding later in the comments
```

(4) There's some logic around keeping `storage_to_nodes` up to date when we see a view op: if we re-inplace `out = a.add(...)`, and later in the program we encounter a "later_node",`out.view(..)`, and need to replace it with `a.view(...)`, then we need to update some metadata structures. I had to fix that logic: specifically, if "later_node" isn't a dispatcher op, (e.g. if it's an FX output node), I wasn't properly handling the case where the node's fake_meta info was not a tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83626
Approved by: https://github.com/ezyang
2022-08-19 23:30:45 +00:00
Sherlock Huang
39e6238788 Support regex-style matching for Any and Oneof (#82853)
pseudo.any is a wildcard node that can be matched with any fx node with an arbitrary number of inputs and outputs.
For example, to match relu followed by one fx node:
```
    def pattern(a):
        y = a.relu()
        z = torch.ops.pseudo.any(y)
        return z
```

pseudo.oneof is a special node that can be matched with an fx node whose target is in the permissible list.
`targets` must be a list of qualified operator names, e.g. ["operator.add", "torch.sigmoid",
"torch.ops.aten.foo", "torch.ops.prims.bar"]

For example, using following pattern with pseudo.oneof
```
    def pattern(a):
        y = a.relu()
        z = torch.ops.pseudo.oneof(y, targets=["relu", "torch.sigmoid", "operator.add"])
        return z
```

It will have 3 matches in the following function
```
    def forward(y):
        z = y.relu()
        x = z.relu()    # first match

        x = x.relu()
        x = torch.sigmoid(x)    # second match

        x = x.relu()
        return x + 1    # third match
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82853
Approved by: https://github.com/ezyang
2022-08-12 18:43:13 +00:00
Sherlock Huang
2ca721cda5 An improved version of subgraph matcher (#82090)
This new version of subgraph matcher further supports
- optionally matching the pattern's placeholder and output nodes
- patterns with multiple outputs
- filtering out non-containing matches
- filtering out overlapping matches

TODOs:
- [x] Update replace_pattern() to use this matcher
- [x] Fix cases with identical anchor
- [x] Introduce wildcard matching, such Any, OneOf
- [ ] Improve node comparer to match args and kwargs values
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82090
Approved by: https://github.com/ezyang
2022-08-12 03:32:09 +00:00
Sergii Dymchenko
a0b3854548 Change seperate -> separate (#83056)
One instance was caught by Meta-internal "exact-word-misspell" linter in D38505529.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83056
Approved by: https://github.com/huydhn, https://github.com/seemethere
2022-08-09 23:11:34 +00:00
Horace He
51bbf6329a Improved legalize_graph pass in FX (#82874)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82874
Approved by: https://github.com/jamesr66a
2022-08-07 00:13:17 +00:00
Shirong Wu
4ae40d74ac Back out "Add an op_lowering_disallow_list in fx splitter base class. (#82288)" (#82750)
Summary:
Revert since this breaks BC test
More context:
failing test
https://www.internalfb.com/.../fblearner/details/361780349/
issue report thread
https://fb.workplace.com/groups/2211200152361974/permalink/2303690223112966/

Test Plan: All unit test

Differential Revision: D38399966

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82750
Approved by: https://github.com/yinghai
2022-08-05 02:15:00 +00:00
Brian Hirsh
d362b8e9e6 reland "add a reinplacing FX pass (#80897)" (#82407)
fixes #81457
fixes #81216
fixes #81212
fixes #81207
fixes #81206
fixes #81218
fixes #81203
fixes #81202
fixes #81214
fixes #81220
fixes #81205
fixes #81200
fixes #81204
fixes #81221
fixes #81209
fixes #81210
fixes #81215
fixes #81217
fixes #81222
fixes #81211
fixes #81201
fixes #81208

As part of this PR I'm also re-enabling all of the functionalization tests that got marked as flaky in CI (they're not actually flaky - I think they got marked because a PR that should have changed their expect-test output made it to master without the changes. I'll let CI run on this PR to confirm though).

reland of https://github.com/pytorch/pytorch/pull/80897
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82407
Approved by: https://github.com/ezyang
2022-08-02 18:03:29 +00:00
Shirong Wu
09059d9148 integrate plugin (#82395)
Differential Revision: D38162861

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82395
Approved by: https://github.com/frank-wei
2022-08-02 00:41:36 +00:00
Angela Yi
e06d1029f7 [fx] Minor modifications to pass infra (#82485)
* Made PassBase calls optionally return PassResult, since some passes
  might want to run in place.
* Added additional documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82485
Approved by: https://github.com/SherlockNoMad
2022-08-01 20:10:01 +00:00
Ying Zhang
a71d0e882c Add an op_lowering_disallow_list in fx splitter base class. (#82288)
Summary: ATT, so that we can control not to lower some specific ops.

Test Plan: Tested together with the next diff in stack.

Differential Revision: D38188836

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82288
Approved by: https://github.com/mikeiovine, https://github.com/khabinov
2022-07-28 05:19:33 +00:00
PyTorch MergeBot
df36ccbd81 Revert "add a reinplacing FX pass (#80897)"
This reverts commit 3ef7a6921d.

Reverted https://github.com/pytorch/pytorch/pull/80897 on behalf of https://github.com/malfet due to broke windows trunk tests, see 3ef7a6921d
2022-07-27 22:32:03 +00:00
Brian Hirsh
3ef7a6921d add a reinplacing FX pass (#80897)
Adds a "reinplacing" FX transform, that goes through an FX graph and tries to convert out-of-place op calls into inplace calls whenever possible.

Followups from this PR include:
- Set up torchbench, and run the whole torchbench suite using AOTAutograd + functionalize + reinplacing transforms to surface any issues (this is what I'm currently working on). Right now, I have some basic unit tests just to sanity check that the general logic makes sense.
- Add any missing inplace ops. This is mostly the `*_scatter*` ops, e.g. `diagonal_scatter_`, because these ops will commonly show up in an FX graph after running functionalization.

The criteria for when you can swap an op `b = a.add(...)` with `a.add_(...)` is:
(1) An inplace variant of the operator with the same schema needs to exist (`aten.add` -> `aten.add_`)
(2) `a` (**or any of its aliases**) can't be used as an input to any other operators later on in the graph
(3) `a` can't be one of the inputs to the entire graph. It also can't be an **alias** of any of the inputs ***

*** One thing to note: (3) means that we can't technically guarantee that we'll get back **all** memory usage that we lost from functionalization. Functionalization converts input mutations into out-of-place calls, and then adds a `copy_()` to the end of the graph to preserve semantics.

I added logic to handle `copy_()` in this PR because it's a pretty important optimization in the context of `functionalization()`: any program that performs input mutations will have a `copy_()` in it after running functionalization.

There are some examples in the test file, but I think staring at an example of where re-inplacing is/isn't allowed to run is helpful:
```
// Before functionalization
def foo(a):
    tmp1 = a.add_(1)
    tmp2 = a.add(2)

// After functionalization
def foo(a)
    tmp1 = a.add(1)
    tmp2 = a.add(2)
    ....
    a.copy_(tmp1)

// After re-inplacing
def foo(a)
    // first add() is safe to re-inplace even though a is a program input,
    // because a's data is overwritten later by a copy_()
    tmp1 = a.add_(1)
    // second add() is NOT safe to re-inplace, because:
    // (1) a and tmp1 are aliased. Note that they weren't aliased in the original program,
    //          but they are now that we've done some re-inplacing.
    // (2) tmp1 is used as an input later in the program
    tmp2 = a.add(2)
    ....
    a.copy_(tmp1)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80897
Approved by: https://github.com/ezyang
2022-07-27 19:11:15 +00:00
Sherlock Huang
dc3c1ade4b Some fixes for FX pass with nvFuser backend (#81911)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81911
Approved by: https://github.com/jjsjann123, https://github.com/IvanYashchuk, https://github.com/davidberard98
2022-07-22 19:49:33 +00:00
Edward Z. Yang
3c2c2cc947 cudagraphs dynamo backend (#80566)
This backend handles cases where the preexisting cuda graphs
implementation from dynamo is unsound/has errors.

Requires this functorch bug fix: https://github.com/pytorch/functorch/pull/935

Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80566
Approved by: https://github.com/ngimel, https://github.com/wconstab
2022-07-22 14:06:07 +00:00
Shangdi Yu
c52ee6dc0a CSE Pass and common pass Tests (#81742)
Test cases for CSE Pass and common passes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81742
Approved by: https://github.com/SherlockNoMad
2022-07-22 03:45:09 +00:00
Sherlock Huang
43e7fee764 [Reland] Recursively print graph module and its submodule (#81639)
ghstack-source-id: fcfc024c440981ee3fe3537a5816089eadf2cc13
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81639
Approved by: https://github.com/ezyang
2022-07-21 16:58:25 +00:00
Shangdi Yu
7c5dac5228 Dialect agnostic CSE Pass (#81530)
Fixes comments in https://github.com/pytorch/pytorch/pull/81512

- banned ops is an input to the pass
- update the fx/readme.md to include this file for better discoverability
- use make_fx in torch repo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81530
Approved by: https://github.com/SherlockNoMad
2022-07-20 00:56:41 +00:00
PyTorch MergeBot
4035a53cca Revert "Recursively print graph module and its submodule (#81080)"
This reverts commit fe7262329c.

Reverted https://github.com/pytorch/pytorch/pull/81080 on behalf of https://github.com/DanilBaibak due to Break internal build
2022-07-18 14:46:26 +00:00
Sherlock Huang
fe7262329c Recursively print graph module and its submodule (#81080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080
Approved by: https://github.com/ezyang
2022-07-18 01:19:03 +00:00
Sherlock Huang
d625637c7c Include aten.where.self in NvFuserOperatorSupport (#81436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81436
Approved by: https://github.com/davidberard98
2022-07-16 03:29:27 +00:00
Shangdi Yu
938643b8bc CSE_Pass (#81512)
Migrate the CSE pass in functorch to pytorch

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81512
Approved by: https://github.com/angelayi
2022-07-15 02:32:48 +00:00
Angela Yi
3d0b0b2f9b [fx] PassManager changes (#80531)
PassManager is a class used to run multiple passes on a given graph module.

Class Attributes
* `passes: List[Callable]`: A list of callable passes
* `constraints: List[Callable]`: A list of constraints
* `run_checks_after_each_pass`: Flag for running checks each pass

Class Methods:
* `__call__(graph_module: DispatchGraphModule)`:
    * Runs the passes in the given list until the graph stops changing, or at most `steps` times.
    * Each time a pass is run, it will check that the graph module still maintains the required invariants by calling `check()` and will lint the graph to check that it’s well formed if the flag `run_checks_after_each_pass` is set.
* `check(graph_module: DispatchGraphModule)`: Runs various checks on the given graph module to make sure that it contains the needed data for passes
* `add_check(check: Callable)`: Adds the `check` function to the given pass manager instance
* `add_constraint(constraint: Callable)`: Adds a constraint to the current list of constraints

We can create a PassManager and run it by doing:
```
PassManager(passes=[pass1, pass2])(graph_module)
```

Differential Revision: [D37523159](https://our.internmc.facebook.com/intern/diff/D37523159)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80531
Approved by: https://github.com/SherlockNoMad
2022-07-15 00:58:43 +00:00
jjsjann123
cc67a92e74 fixing call_module on subscripting into generator (#81258)
named_modules() returns a generator, which is not subscriptable and causes the node support query to fail.
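
A small illustration of the failure mode and the materialized lookup (hedged; not the partitioner's actual code):

```python
import torch.nn as nn

mod = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
# named_modules() returns a generator, so subscripting it raises:
#   mod.named_modules()["0"]  ->  TypeError: 'generator' object is not subscriptable
# Materializing it into a dict first makes name-based lookup work:
submodules = dict(mod.named_modules())
print(type(submodules["0"]).__name__)  # Linear
```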
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81258
Approved by: https://github.com/SherlockNoMad
2022-07-14 16:41:18 +00:00
Angela Yi
614779f975 [fx] PassResult (#81366)
Passes should now return a `PassResult` which (for now) contains the following fields:
* `graph_module`: The graph module modified during the pass
* `modified`: A flag for if the graph module has been modified
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81366
Approved by: https://github.com/SherlockNoMad
2022-07-13 02:03:11 +00:00
Sherlock Huang
6b280e880a Update NvFuserOperatorSupport (#81311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81311
Approved by: https://github.com/davidberard98
2022-07-12 21:19:37 +00:00
Sherlock Huang
fc10a63727 Prims+NvFuser Backend Prototype (#80591)
This PR integrates FX graph partitioner + Aten2Prims DecompositionInterpreter + Prims' TraceExecutor + naive caches for nvFuser.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80591
Approved by: https://github.com/jjsjann123, https://github.com/ezyang
2022-07-08 19:53:03 +00:00
anjali411
4bf076e964 Add __all__ to torch.distributed, futures, fx, nn, package, benchmark submodules (#80520)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80520
Approved by: https://github.com/rohan-varma
2022-07-08 14:31:24 +00:00
Drazen Borkovic
9402219a36 Move serialize_module() out of OSS graph_manipulation.py to internal (#80785)
Differential Revision: D37582495

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80785
Approved by: https://github.com/jfix71
2022-07-05 23:39:13 +00:00
Riley Dulin
d579838eb5 [torch][fx] Add ignore_parameters_and_buffers kwarg to FxGraphDrawer (#79982)
Summary:
Add an `ignore_parameters_and_buffers` parameter which will tell the graph drawer
to leave off adding parameter and buffer nodes in the dot graph.

This is useful for large networks, where we want to view the graph to get an idea of
the topology and the shapes without needing to see every detail. Removing these buffers
de-clutters the graph significantly without losing much information.
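
A hedged usage sketch (the kwarg name comes from this commit; requires pydot to be installed, and the toy module is illustrative):

```python
import torch
import torch.fx as fx
from torch.fx.passes.graph_drawer import FxGraphDrawer

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x).relu()

gm = fx.symbolic_trace(Net())
# Leave parameter/buffer nodes out of the rendered dot graph to reduce clutter.
drawer = FxGraphDrawer(gm, "net", ignore_parameters_and_buffers=True)
dot_graph = drawer.get_dot_graph()
```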

Reviewed By: jfix71

Differential Revision: D37317917

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79982
Approved by: https://github.com/jfix71
2022-06-29 22:48:43 +00:00
Sherlock Huang
ac5a94789f Refactor lift_subgraph_as_module as a fx.passes.util function (#80292)
lift_subgraph_as_module can be shared between fuser_utils.py and spliter_utils.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80292
Approved by: https://github.com/jjsjann123, https://github.com/842974287
2022-06-29 22:35:39 +00:00
PyTorch MergeBot
58532256e9 Revert "Add __all__ for torch.distributed and fx modules (#80460)"
This reverts commit 5d40c3d5c8.

Reverted https://github.com/pytorch/pytorch/pull/80460 on behalf of https://github.com/malfet due to Broke MacOS testing, see https://github.com/pytorch/pytorch/runs/7105579664?check_suite_focus=true
2022-06-29 16:20:55 +00:00