Commit Graph

152 Commits

Author SHA1 Message Date
Jason Ansel
edaff88f69 [fx] Move map_aggregate to C++ (#148243)
Microbenchmarking `fx.symbolic_trace(lambda x: functools.reduce(operator.add, [x, *range(100000)]))`, before:
```
30603618 function calls (29403419 primitive calls) in 13.744 seconds
```
after:
```
25203549 function calls (24403352 primitive calls) in 12.090 seconds
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148243
Approved by: https://github.com/oulgen
2025-03-02 22:42:31 +00:00
Xuehai Pan
cba14212e6 [FX] micro-optimization map_aggregate(immutable_dict) (#147691)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147691
Approved by: https://github.com/Skylion007, https://github.com/jansel
ghstack dependencies: #147699, #144640
2025-02-24 09:14:08 +00:00
Simon Fan
ac88a6c00d [fx] demote node prepend to self log from warning to debug (#147538)
FIXES https://github.com/pytorch/pytorch/issues/147175

This is harmless, not sure why this is a user warning. Writing reordering graph passes is more concise when we ignore this warning.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147538
Approved by: https://github.com/yanboliang
2025-02-21 01:32:34 +00:00
Tom Ritchford
272ead7b5e Make fx.node.map_arg() and .map_aggregate() generic (#146248)
## What's the problem?

The popular `fx.node.map_arg()` and `fx.node.map_aggregate()` apply operations recursively on `dict`s, `tuples`, `list`s, etc, and return a new collection of the same type.

Unfortunately, their base input type is `Argument`, which is [very unspecific indeed](5d55a6585d/torch/fx/node.py (L48-L58)): most type information is just thrown away at the call site of either of these functions, as far as the type checker goes.

As `torch` moves to a more typed code base, this would force innocent, unsuspecting developers to add logically unnecessary casts or `# type: ignore` statements.

## What's the solution?

Making these two `node.map_*` functions generic on the first argument and return type means that type information is preserved for the type checker. (The signature of the other parameter, the function that visits the nodes and subnodes, has not changed, nor should it.)

## Won't it break everything?

It doesn't break the type checker - one place needed an extra hint.

There have been code breakages, resolved one, at least one new one... we'll see!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146248
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007
2025-02-14 19:25:32 +00:00
Aaron Orenstein
1f8ff94d4f PEP585: Add noqa to necessary tests (#146391)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146391
Approved by: https://github.com/justinchuby, https://github.com/Skylion007
2025-02-12 15:29:50 +00:00
Aaron Orenstein
57d8278ab9 pickler for GraphModule (#141659)
Pickling GraphModule needs some special handling for wrapping things that normally can't be pickled - but async compile needs to pass them across a wire so we need to be able to serialize it - add some helpers to enable that.

Differential Revision: [D68921318](https://our.internmc.facebook.com/intern/diff/D68921318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141659
Approved by: https://github.com/jamesjwu
2025-01-31 05:34:28 +00:00
Yidi Wu
d1143c4b37 [export] fix non-strict pre_dispatch exporting while_loop (#145762)
fix https://github.com/pytorch/pytorch/issues/145737.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145762
Approved by: https://github.com/tugsbayasgalan, https://github.com/zou3519, https://github.com/avikchaudhuri
2025-01-30 18:58:34 +00:00
PyTorch MergeBot
2de53b3b65 Revert "pickler for GraphModule (#141659)"
This reverts commit c6ad08357b.

Reverted https://github.com/pytorch/pytorch/pull/141659 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally, please take a look at D68694181 for more details. ([comment](https://github.com/pytorch/pytorch/pull/141659#issuecomment-2617045120))
2025-01-27 22:39:30 +00:00
leslie-fang-intel
2e80093306 setitem node shouldn't be deadcode eliminated (#145714)
**Summary**
Fix issue https://github.com/pytorch/pytorch/issues/145697. The `operator.setitem` has been eliminated as dead code, causing a correctness issue. Mark it as impure in this PR to avoid this side effect.

**TestPlan**
```
python -u -m pytest -s -v test/fx/test_dce_pass.py -k test_keep_setitem
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145714
Approved by: https://github.com/ezyang
2025-01-27 15:08:21 +00:00
Aaron Orenstein
c6ad08357b pickler for GraphModule (#141659)
Pickling GraphModule needs some special handling for wrapping things that normally can't be pickled - but async compile needs to pass them across a wire so we need to be able to serialize it - add some helpers to enable that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141659
Approved by: https://github.com/jamesjwu
2025-01-26 19:29:13 +00:00
Simon Fan
27598cd154 [fx] move DCE rand check to import time (#145118)
Mitigates the deterministic benchmark regression: https://github.com/pytorch/pytorch/issues/144775#issuecomment-2593411844. and maybe the dashboard issue.

fx.Node.is_impure is unexpectedly a hot spot. It gets called for every node in the graph whenever we invoke DCE, which should be okay, EXCEPT we invoke DCE on the full graph ~10 times at various stages of torch.compile, and an insane number of times (>O(parameters)) for the subgraphs traced by the pattern matcher.

I considered addressing this problem by reducing the amount of times DCE is called, but I think we can only trim the ones from the pattern matcher, which will require some refactor/caching solution that I leave out of this PR.

torch.Tag.nondeterministic_seeded is provided by native_functions.yml and is implemented as a list. Most of the time, it has <=2 elements, so it's not really worth it to turn it into a set for fast lookup.

Using the deterministic instruction count benchmarks
```python
# before
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8914894946
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8866669058
# after
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8770562314
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8779547794
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145118
Approved by: https://github.com/ezyang, https://github.com/zou3519
2025-01-22 02:23:02 +00:00
Aaron Orenstein
0b2a3687b9 PEP585 update - torch/fx (#145166)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166
Approved by: https://github.com/bobrenjc93
2025-01-20 18:11:54 +00:00
Simon Fan
7f1946aa9b [aot] don't dce aten rng nodes (#144319)
FIXES https://github.com/pytorch/pytorch/issues/143431

For aot_eager backend, we dce twice in aot. The first dce errs on the side of caution and provides a restrictive dce function: 2e1ea8598f/torch/fx/experimental/proxy_tensor.py (L1173)

The second one is more aggressive: 2e1ea8598f/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py (L185)
But this deviates from eager accuracy when rand ops are dce'd

The repro doesn't work for inductor, but that's a separate issue

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144319
Approved by: https://github.com/jansel
2025-01-09 05:27:49 +00:00
Fabian Keller
8cb68b136f Proper modeling of recursive types (#142300)
Currently there are a few type annotations that falsely state that mypy doesn't support recursive types.

Recursive type support is available in mypy for a few years already. It has been officially enabled in [version 0.991](https://mypy-lang.blogspot.com/2022/11/mypy-0990-released.html). Pyright even had support for recursive types earlier (https://github.com/microsoft/pyright/issues/569), so there is probably no reason not to model these types correctly.

This PR models these types properly now. Since this has turned a few implicit `Any` into fully typed variables that are not narrowed cleanly, a small number of type ignores were necessary.

Note that regarding the `Argument` it is desirable to model it in a covariant way (i.e. using `Sequence` and `Mapping`) instead of making it invariant unnecessarily (using `List` and `Dict`). If it were modeled invariant, it would for instance mean that a `List[Node]` would not type check as `Argument`, because invariance would mean that it really has to be a `List[Argument]` (i.e., including all the branches of the union type). Since even the name of the type "argument" strongly suggest that it is semantically used as "argument", having covariance natural anyway.

There are no chances in this PR that affect runtime behavior.

CC @Skylion007

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142300
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2024-12-07 21:30:45 +00:00
Shangdi Yu
51cbac4e6a [export] Change fx graph _replace_hook to a list of Callable (#142006)
Summary: Change fx graph module's _replace_hook from a single hook, to a list of hooks. This is to prepare to registering more hooks for inductor provenance tracking, where we might need to register multiple hooks for node replacement.

Test Plan:
```
buck run mode/dev-nosan caffe2/test:fx -- -r test_hooks_for_node_update
buck run mode/dev-nosan caffe2/test:test_export -- -r test_replace_hook
```

Differential Revision: D66726724

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142006
Approved by: https://github.com/zhxchen17
2024-12-05 03:26:48 +00:00
angelayi
0fbc0830ba [export] Add device and dtype fields to assert_tensor_metadata (#141071)
Differential Revision: [D66321128](https://our.internmc.facebook.com/intern/diff/D66321128)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141071
Approved by: https://github.com/yushangdi, https://github.com/zou3519
2024-11-22 20:54:55 +00:00
Xuehai Pan
abbd71d29d [BE][Easy] enable PYFMT for torch.fx (#138443)
Reproduce command:

```bash
ghstack checkout https://github.com/pytorch/pytorch/pull/138443
git checkout HEAD~1 torch/
lintrunner -a --take "PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138443
Approved by: https://github.com/ezyang
2024-10-21 19:15:49 +00:00
Jason Ansel
28330a8a39 [reland 1/3][fx] Bypass custom __setattr__ in Node.__init__ (#135733)
Relands #135079 whcih was reverted by #135562

I broke this up into three parts to test internally.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135733
Approved by: https://github.com/oulgen
2024-09-12 04:29:37 +00:00
Ivan Zaitsev
440f8f57af Revert "[fx] Bypass custom __setattr__ in Node.__init__ (#135079)" (#135562)
This reverts commit 66da3b3b2a.

#135079 breaks internal tests and needs to be reverted. Revert with mergebot doesn't work as this PR is technically part of the stack, but, according to @jansel, it should be possible to revert it individually.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135562
Approved by: https://github.com/jansel, https://github.com/seemethere
2024-09-10 18:07:11 +00:00
Jason Ansel
66da3b3b2a [fx] Bypass custom __setattr__ in Node.__init__ (#135079)
Before:
![image](https://github.com/user-attachments/assets/5f0a6ae6-6049-44d0-b5f2-a549a23ad97f)

After:
![image](https://github.com/user-attachments/assets/51c9f91b-f8a0-4043-8362-65813feec823)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135079
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082, #135084
2024-09-06 06:11:46 +00:00
Jason Ansel
bdfc8d9f96 [fx] Don't use generators in map_aggregate (#135082)
While the generators avoid a copy, they are slow.

Before:
![image](https://github.com/user-attachments/assets/70a55a9a-0595-4105-b0ab-22cf77c7409c)

After:
![image](https://github.com/user-attachments/assets/cecb9c59-ae36-47de-8b08-cab2c7cb3d57)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135082
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076
2024-09-05 23:41:30 +00:00
Jason Ansel
70779dded8 [fx] Compile time optimization in Node.__update_args_kwargs (#135076)
Before this we took two passes over all of the args.

Before:
![image](https://github.com/user-attachments/assets/24ce5628-03f4-4983-9f2d-5ddf0ca5816e)

After:
![image](https://github.com/user-attachments/assets/c9681aa2-32f0-4f6b-a598-fc6f90ffafb5)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135076
Approved by: https://github.com/Chillee
ghstack dependencies: #135070
2024-09-05 23:41:30 +00:00
Aaron Orenstein
ed86ac2f25 [BE] typing for decorators - fx/_compatibility (#134054)
Summary: See #131429

Test Plan: unit tests pass

Differential Revision: D61493706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134054
Approved by: https://github.com/oulgen
2024-08-26 04:00:27 +00:00
Aaron Orenstein
d95aedf5fd [BE] typing for decorators - fx/_compatibility (part 1) (#134202)
Part of #134054.

This corresponds to the pytorch mypy changes from D61493706. Updating takes so
long and touches so many files that it's impossible to land as a whole without conflicting with some other intermediate change.
So landing these 'type: ignore' for pytorch in advance of them actually being needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134202
Approved by: https://github.com/Skylion007
2024-08-22 17:07:33 +00:00
Shangdi Yu
825002c9c6 [export][fx] More robust DCE pass (#132764)
Summary:
- make default DCE pass check schema,
- need to rebase onto https://github.com/pytorch/pytorch/pull/131651 after it's in phabricator (for now the change is manually added).

- mark Proxy dump as NotImplemented for better error msg

- Remove Proxy from tensors when dumping models, as Proxy cannot be dumped.

More details in https://docs.google.com/document/d/1G5vmTXjzxoyVGRI2kpA1gQukK_Glyg2NrE0Oh6Nlg9A/edit?usp=sharing.

Test Plan:
CI
```
- buck2 run 'fbcode//mode/dev-nosan'  fbcode//caffe2/test/quantization:test_quantization -- -r  qat_conv2d
- test_export.py
- buck2 run 'fbcode//mode/dev-nosan' fbcode//modai/test:test_modai -- -r test_qat_stinson_htp_export
- buck2 run 'fbcode//mode/dev-nosan' fbcode//vizard_projects/ml_depth/tests:test_model -- -r test_qat_model_et
- buck2 run 'fbcode//mode/dev-nosan'  fbcode//caffe2/test:fx -- -r dce
- buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=False,use_3d_input=False
- buck2 run 'fbcode//mode/dev-nosan' fbcode//bolt/nn/executorch/backends/tests:qnn_test -- -r test_qat_bias=True,use_3d_input=False
- buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r  test_fold_bn_erases_bn_node
```

Reviewed By: angelayi

Differential Revision: D60319175

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132764
Approved by: https://github.com/angelayi
2024-08-06 22:27:22 +00:00
PyTorch MergeBot
945bf78894 Revert "[BE] typing for decorators - fx/_compatibility (#131568)"
This reverts commit 193f62fde9.

Reverted https://github.com/pytorch/pytorch/pull/131568 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
Aaron Orenstein
193f62fde9 [BE] typing for decorators - fx/_compatibility (#131568)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
2024-07-25 22:24:19 +00:00
rzou
98984422eb [triton_op] fix autotuning (#131363)
The problem was we were shoving SymInts into the constant_args side
table. The root problem is that torch.fx.node.base_types, which we use
to determine what can be put in the graph, doesn't actually have SymInt
in it. This PR fixes base_types to include SymInt.

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131363
Approved by: https://github.com/oulgen, https://github.com/justinchuby
2024-07-24 14:03:37 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
PyTorch MergeBot
6b8ec2b371 Revert "[triton_op] fix autotuning (#131363)"
This reverts commit 154f27455a.

Reverted https://github.com/pytorch/pytorch/pull/131363 on behalf of https://github.com/ZainRizvi due to This was a tricky one, but looking at the code it's the change to torch/fx/node.py that triggered the type violation errors. Reverting since this is now breaking trunk ([comment](https://github.com/pytorch/pytorch/pull/131363#issuecomment-2245899858))
2024-07-23 18:01:09 +00:00
Oguz Ulgen
4ca8705035 Add mypy typing to fx node (#131434)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131434
Approved by: https://github.com/zou3519
2024-07-23 17:00:31 +00:00
rzou
154f27455a [triton_op] fix autotuning (#131363)
The problem was we were shoving SymInts into the constant_args side
table. The root problem is that torch.fx.node.base_types, which we use
to determine what can be put in the graph, doesn't actually have SymInt
in it. This PR fixes base_types to include SymInt.

Test Plan:
- tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131363
Approved by: https://github.com/oulgen
2024-07-23 16:15:00 +00:00
Shangdi Yu
29e2e2afb6 Revert D59561509: Multisect successfully blamed "D59561509: [FX][export] DCE pass, check schema for node impurity (#130395)" for one test failure (#131341)
Summary:
This diff reverts D59561509
D59561509: [FX][export] DCE pass, check schema for node impurity (#130395) by yushangdi causes the following test failure:

Tests affected:
- [cogwheel:cogwheel_mtia_cmf_m5_shrunk_test#test_flow_with_verification](https://www.internalfb.com/intern/test/844425041436985/)

Here's the Multisect link:
https://www.internalfb.com/multisect/6533402
Here are the tasks that are relevant to this breakage:
T191383430: 10+ tests unhealthy for ads_mtia_inference

The backout may land if someone accepts it.

If this diff has been generated in error, you can Commandeer and Abandon it.

Test Plan: NA

Differential Revision: D60029318

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131341
Approved by: https://github.com/angelayi
2024-07-23 05:23:47 +00:00
Shangdi Yu
27ded03545 [FX][export] DCE pass, check schema for node impurity (#130395)
Change the default DCE pass to check node schema for impure nodes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130395
Approved by: https://github.com/angelayi, https://github.com/jgong5
2024-07-18 16:31:40 +00:00
PyTorch MergeBot
433ef4e444 Revert "[FX][export] DCE pass, check schema for node impurity (#130395)"
This reverts commit e22b0acc76.

Reverted https://github.com/pytorch/pytorch/pull/130395 on behalf of https://github.com/yushangdi due to breaking tests, need to rebase and fix ([comment](https://github.com/pytorch/pytorch/pull/130395#issuecomment-2235192986))
2024-07-18 02:46:03 +00:00
Shangdi Yu
e22b0acc76 [FX][export] DCE pass, check schema for node impurity (#130395)
Change the default DCE pass to check node schema for impure nodes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130395
Approved by: https://github.com/angelayi, https://github.com/jgong5
2024-07-18 00:55:20 +00:00
Brian Hirsh
a4d7aa498b [Traceable FSDP2] Add auto-functionalize support for mutable list[Tensor] (copy from Brian's PR #127347); enable E2E inductor unit test for transformer model (#129502)
Copy of Brian's PR: https://github.com/pytorch/pytorch/pull/127347 with additional changes to support mutable `List[Tensor]` in Inductor. Also enable E2E inductor unit test for Traceable FSDP2 + transformer model.

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_set_`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_simple_mlp_fullgraph_backend_aot_eager`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_simple_mlp_fullgraph_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_fullgraph_backend_aot_eager`
- `pytest -rA test/dynamo/test_misc.py::MiscTests::test_auto_functionalize_tensorlist`
- `pytest -rA  test/inductor/test_torchinductor.py::GPUTests::test_fallback_mutable_op_list_cuda`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129502
Approved by: https://github.com/zou3519
2024-06-27 17:50:57 +00:00
PyTorch MergeBot
45b2931b7e Revert "[Traceable FSDP2] Don't decompose fsdp.split_with_sizes_copy (#129414)"
This reverts commit b24787b757.

Reverted https://github.com/pytorch/pytorch/pull/129414 on behalf of https://github.com/ZainRizvi due to This PR is seems to be causing multiple macos failures.  Looks like it was merged before trunk jobs were started, which would have run those tests ([comment](https://github.com/pytorch/pytorch/pull/129414#issuecomment-2189479505))
2024-06-25 17:05:55 +00:00
Will Feng
b24787b757 [Traceable FSDP2] Don't decompose fsdp.split_with_sizes_copy (#129414)
This makes it easier to do pattern-matching on `fsdp.split_with_sizes_copy` in Inductor passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129414
Approved by: https://github.com/bdhirsh
2024-06-25 03:08:56 +00:00
Brian Hirsh
b91a9dc328 [Brian's PR #128754] Use torch.ops.fsdp.set_ for FSDP2 storage resize; dont functionalize resize_, set_, split_with_sizes_copy.out (#129203)
This is a copy of Brian's PR https://github.com/pytorch/pytorch/pull/128754, with some changes in the test_distributed_patterns.py unit tests to more closely reflect FSDP2 patterns. Also disabled two tests `test_input_mutation_storage_resize_up_down` and `test_input_mutation_storage_resize_not_supported` in test_aotdispatch.py until we figure out the right behavior for them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129203
Approved by: https://github.com/bdhirsh
2024-06-23 06:07:19 +00:00
Will Feng
e165a5971f [Traceable FSDP2] Fix support for CUDA resize_storage_bytes_ (#129215)
Currently if `x` is a CUDA tensor, calling `x.untyped_storage().resize_()` seems to always go into the `built without cuda` branch of `resize_storage_bytes_()` regardless of whether PyTorch is built with CUDA. I suspect this is because `inductor_ops.cpp` is only included in `libtorch_cpu.so` thus doesn't have the `USE_CUDA` information or ability to link to CUDA-related functions.

This PR moves `resize_storage_bytes_()` related custom op functions out of `inductor_ops.cpp` into its standalone file `resize_storage_bytes.cpp` to be included in `libtorch_python.so` instead. This mimics the setup for `StorageMethods.cpp`. This way, `resize_storage_bytes_()` can have access to the CUDA-related functions, which passes the CUDA unit test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129215
Approved by: https://github.com/jansel
2024-06-22 18:38:47 +00:00
Oguz Ulgen
5b5d269d34 Speed up fx graph iteration by implementing it in C++ (#128288)
Before this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 19.5s (5132266 nodes/s)
```

After this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 3.4s (29114001 nodes/s)
```

5.7x improvement

Differential Revision: [D58343997](https://our.internmc.facebook.com/intern/diff/D58343997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128288
Approved by: https://github.com/jansel, https://github.com/albanD
2024-06-11 05:48:31 +00:00
Aaron Orenstein
7a60a75256 Add typing annotations to pattern_matcher.py (#127458)
Turn on `mypy: disallow-untyped-defs` in pattern_matcher.py and fix the fallout.

There are still a bunch of `type: ignore` annotations which should eventually be ironed out.

In the processs found a bug: #127457

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127458
Approved by: https://github.com/Skylion007
ghstack dependencies: #127457
2024-06-04 15:24:47 +00:00
Brian Hirsh
9c3b87833a AOTAutograd: keep set_() input mutations in the graph, ban other cases (#122981)
We have some (limited) support for `set_()` input mutations in `torch.compile`, but one restriction today is that we force them to run outside of the graph, in the opaque runtime epilogue.

This is a problem for ppFSDP. Why? The usage pattern of ppFSDP forward graphs look something like this:
```
def forward_fsdp(sacrificial_param, sharded_param, inp):
    allgathered_param = allgather(sharded_param)
    sacrificial_param.set_(allgathered_param)  # hidden in an autograd.Function that we trace
    out = matmul(sacrificial_param, inp)
    sacrificial_param.untyped_storage().resize_(0)
    return out
```
When we functionalize this graph, `sacrificial_param` sees two distinct types of input mutations, that we must preserve: a `set_`, and a `resize_`. Importantly, at runtime the `set_()` must run **before** the `resize_()`. Why? the `set_()` updates the storage of our sacrificial param to the allgather'd data, which allows the call to `sacrificial_param.resize_()` to free the allgathered data later. If we run the two mutations in reverse order, we will never free the allgathered data.

We want to put the `resize_()` mutation op inside of the graph (see next PR, also there's a much longer description in that PR for anyone interested). However, this will require us to put `set_()` in the graph as well, in order for them to run in the correct order.

In order to do this, I had to add some extra restrictions: You are now required to run `set_()` under `no_grad()` if you use it with `torch.compile`, and if you perform any other mutations to the input, those must be under no_grad as well (otherwise, the mutations may mutate the `grad_fn` of the input, making it no longer safe to keep in the graph). These restrictions are hopefully reasonable, since `set_()` doesn't see much usage today (and the original impetus for adding set_() support a few months ago was for fsdp anyway)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122981
Approved by: https://github.com/jansel
ghstack dependencies: #122433, #123646
2024-04-11 18:21:57 +00:00
Oguz Ulgen
526a69f5ee Remove incorrect check (#123616)
Summary: This was a micro optimization that I thought would save time but it is not correct. For example, we cannot compare fake tensors.

Test Plan:
```
buck2 run 'fbcode//mode/opt' fbcode//langtech/edge/ns/tools/tests:test_ns_jit_traced_model_all_optimization_f328819347_portal_ns
```
now passes

Differential Revision: D55904083

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123616
Approved by: https://github.com/aakhundov
2024-04-09 08:45:34 +00:00
Oguz Ulgen
03b13851d9 [FX] Add side table to FX Graph for O(1) op/target query (#121565)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121565
Approved by: https://github.com/jansel
2024-04-07 18:51:05 +00:00
Jacob Szwejbka
41d24df08f [export] hack skip index_put_ in dce (#122683)
Summary: Ideally we should do whats in the todo. Just doing this for now to unblock llama capture

Test Plan: capturing llama and using pt2e to quantize it

Differential Revision: D55354487

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122683
Approved by: https://github.com/kimishpatel
2024-03-26 08:05:06 +00:00
Jason Ansel
18d94d7165 Make FX nodes sortable (#122071)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122071
Approved by: https://github.com/oulgen
2024-03-19 01:40:56 +00:00
Jason Ansel
75a6d6aef7 [inductor] Support storage resizing (#119749)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119749
Approved by: https://github.com/yf225
ghstack dependencies: #119647, #119671
2024-02-14 03:03:38 +00:00
Edward Z. Yang
9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00