Commit Graph

178 Commits

Author SHA1 Message Date
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Shangdi Yu
ea4b80e6d6 [FX][export] strict DCE pass, check schema for node impurity (#130552)
Fixes the failure in `test/export/test_export_training_ir_to_run_decomp.py ` caused by dead code elimination removing node with side effects.

For background, in export, we may want to export higher-level IRs that are not functional, so we need to check for side effects more carefully.

 A call_function node is impure if it has at least one mutable argument.

Fixed the tests below:

test_to_module_with_mutated_buffer_multiple_update_sub_later
test_export_input_mutation_static_shape
test_buffer_util

Another attempt modifying the original DCE pass is made in PR #130395, but it breaks some other tests, so here we add a flag and use it for export only.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130552
Approved by: https://github.com/pianpwk
2024-07-12 15:43:27 +00:00
Chen Lai
721a798886 add bits16 to graph dtype_abbrs (#130339)
As title, patch the dtype in torch.fx.graph
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130339
Approved by: https://github.com/angelayi
2024-07-09 19:58:51 +00:00
angelayi
e9c6e8369c Torchbind call method + effects support (#128397)
Adds effect token support to torchbind method calls by allowing `with_effects` to take in `torch.ops._higher_order_ops.call_torchbind` as an input.

Here is the print from `TORCH_LOGS="aot" python test/export/test_torchbind.py -k test_compile_obj_torchbind_op`:
```python
def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2]", arg2_1):
    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1266 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
    cos: "f32[2]" = torch.ops.aten.cos.default(arg1_1)
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, cos);  arg0_1 = cos = None
    getitem: "f32[0]" = with_effects[0];  with_effects = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1267 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
    cos_1: "f32[2]" = torch.ops.aten.cos.default(arg1_1)
    add: "f32[2]" = torch.ops.aten.add.Tensor(cos_1, 1);  cos_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, add);  getitem = add = None
    getitem_2: "f32[0]" = with_effects_1[0];  with_effects_1 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1268 in f, code: torch.ops._TorchScriptTesting.queue_pop(tq)
    with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg2_1);  getitem_2 = None
    getitem_4: "f32[0]" = with_effects_2[0];  with_effects_2 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1269 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
    sin: "f32[2]" = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, sin);  getitem_4 = sin = None
    getitem_6: "f32[0]" = with_effects_3[0];  with_effects_3 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1270 in f, code: return tq.pop(), tq.pop() + tq.size(), tq
    with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop');  getitem_6 = None
    getitem_8: "f32[0]" = with_effects_4[0]
    getitem_9: "f32[2]" = with_effects_4[1];  with_effects_4 = None
    with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop');  getitem_8 = None
    getitem_10: "f32[0]" = with_effects_5[0]
    getitem_11: "f32[2]" = with_effects_5[1];  with_effects_5 = None
    with_effects_6 = torch._higher_order_ops.effects.with_effects(getitem_10, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'size');  getitem_10 = arg2_1 = None
    getitem_12: "f32[0]" = with_effects_6[0];  with_effects_6 = None
    add_1: "f32[2]" = torch.ops.aten.add.Tensor(getitem_11, 0);  getitem_11 = None
    return (getitem_12, getitem_9, add_1)
```

In order to support this, this PR makes the following changes:
* Adds `FakeScriptObject` to `CustomObjArgument`, which will be put on the `meta["val"]` of nodes representing torchbind objects.
* Adds pickle/deepcopy support to FunctionSchema.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128397
Approved by: https://github.com/ydwu4, https://github.com/zou3519
2024-06-14 21:28:17 +00:00
chilli
c486e2ab64 Add coloring to fx graph print out (#128476)
Note: Won't land immediately, at least I'll need to add a color option to the field. But curious if any tests fail.

Old:
<img width="1294" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/c3a750ed-5e54-4621-b2e4-be5481be15b6">

New:
<img width="1303" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/3a1f1adc-6f3a-413e-8b87-ee53da9bf4ed">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128476
Approved by: https://github.com/ezyang
2024-06-13 23:39:04 +00:00
Oguz Ulgen
5b5d269d34 Speed up fx graph iteration by implementing it in C++ (#128288)
Before this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 19.5s (5132266 nodes/s)
```

After this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 3.4s (29114001 nodes/s)
```

5.7x improvement

Differential Revision: [D58343997](https://our.internmc.facebook.com/intern/diff/D58343997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128288
Approved by: https://github.com/jansel, https://github.com/albanD
2024-06-11 05:48:31 +00:00
Aaron Orenstein
038b927590 Flip default value for mypy disallow_untyped_defs [7/11] (#127844)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843
2024-06-08 18:49:45 +00:00
Xuehai Pan
8b08b0f340 [BE] enable ruff rule Q from flake8-quotes (#127713)
Enable [ruff rule `Q`](https://docs.astral.sh/ruff/rules/#flake8-quotes-q) from flake8-quotes. Fixes:

- [avoidable-escaped-quote (Q003)](https://docs.astral.sh/ruff/rules/avoidable-escaped-quote/#avoidable-escaped-quote-q003)
- [unnecessary-escaped-quote (Q004)](https://docs.astral.sh/ruff/rules/unnecessary-escaped-quote/#unnecessary-escaped-quote-q004)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127713
Approved by: https://github.com/ezyang
2024-06-02 23:25:26 +00:00
Sheng Fu
bbeb0906c4 Register creak_node_hook (#126671)
Differential Revision: D57469157

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126671
Approved by: https://github.com/angelayi
2024-05-24 23:32:15 +00:00
angelayi
8be4c1bc2f [export] Add metadata for nodes insert_deferred_runtime_asserts (#125414)
Fixes [internal error](https://fb.workplace.com/groups/1075192433118967/permalink/1416709435633930/).

The issue is that the asserting nodes added in the `insert_deferred_runtime_assertion` pass do not contain metadata that the ExportedProgram requires the graph to have. One solution to fix this is to retrace the entire module, or another solution is to manually add back this metadata.

This diff implements the latter solution (manually add back the metadata) through hooking into fx.graph's `create_node` function, and adding export-specific metadata for every node that is created. The reason I did this is so that the `insert_deferred_runtime_assertion` does not have to know about what metadata export wants.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125414
Approved by: https://github.com/zhxchen17, https://github.com/BoyuanFeng
2024-05-07 23:15:21 +00:00
Edward Z. Yang
ecd62746e3 Also pull size/stride info from example_value (#125505)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125505
Approved by: https://github.com/jansel
2024-05-05 22:27:46 +00:00
Simon Fan
43a7ab2a21 [compiled autograd] introduce verbose logs, add autograd node info to graph (#124954)
- sets it as a fake stack trace as we don't have a generic comment feature
- when verbose is disabled, still adds a contextmanager and flag checks. the alternative is to use MACROS, but that wouldn't be usable with TORCH_LOGS

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124954
Approved by: https://github.com/jansel
2024-04-27 01:10:37 +00:00
Sherlock Huang
c2f687f32c Option to include stride and device annotation in gm.print_readable() (#123690)
Summary:
Sample output for gm.print_readable(include_stride=True, include_device=True)

```
        getitem_21: "i32[1200][1]cuda:0" = auto_functionalized_4[1]
        copy_2: "f32[2, 60][60, 1]cuda:1"  = ....
```

Test Plan: CI

Differential Revision: D55949129

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123690
Approved by: https://github.com/Chillee
2024-04-11 06:53:10 +00:00
Oguz Ulgen
03b13851d9 [FX] Add side table to FX Graph for O(1) op/target query (#121565)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121565
Approved by: https://github.com/jansel
2024-04-07 18:51:05 +00:00
Oguz Ulgen
7c5e29ae71 Back out "Support triton.language.dtype with torch.compile (#121690)" (#122108)
Summary: Some hard to deal with package import/export related problems. Lets revert and start with clean slate.

Test Plan: CI

Differential Revision: D55024877

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122108
Approved by: https://github.com/ezyang
2024-03-18 20:50:28 +00:00
Oguz Ulgen
65ccac6f17 Fix triton import time cycles (#122059)
Summary: `has_triton` causes some import time cycles. Lets use `has_triton_package` which is enough.

Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//fblearner/flow/projects/model_processing/pytorch_model_export_utils/logical_transformations/tests:filter_inference_feature_metadata_test -- --exact 'fblearner/flow/projects/model_processing/pytorch_model_export_utils/logical_transformations/tests:filter_inference_feature_metadata_test - test_collect_features_from_graph_module_nodes (fblearner.flow.projects.model_processing.pytorch_model_export_utils.logical_transformations.tests.filter_inference_feature_metadata_test.FilterInferenceFromFeatureMetadataTest)'
```
now passes

Differential Revision: D55001430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122059
Approved by: https://github.com/aakhundov
2024-03-18 05:50:32 +00:00
Oguz Ulgen
e39aedfcc5 Fix fx graph triton import bug (#122041)
Summary: Unless we register triton to be a special import, FX graph import mechanism imports it as `from fx-generated._0 import triton as triton` which is obviously broken.

Test Plan:
I could not figure out how to write a test for this but
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//tgif/lib/tests/gpu_tests:lowering_pass_test -- -r test_default_ait_lowering_multi_hardwares
```
now passes

Differential Revision: D54990782

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122041
Approved by: https://github.com/aakhundov
2024-03-17 22:48:51 +00:00
Oguz Ulgen
79ee6bbde3 Support triton.language.dtype with torch.compile (#121690)
Putting this PR as an RFC since I have resorted to some horrible hacks in order to make this work.
```
(Pdb) p triton.language.float32
triton.language.fp32
(Pdb) p str(triton.language.float32)
'fp32'
(Pdb) p repr(triton.language.float32)
'triton.language.fp32'
```
This means that we need to "rewrite" them for fx graph and inductor execution.

This PR allows Mamba2 to work with `torch.compile`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690
Approved by: https://github.com/Skylion007
2024-03-12 23:21:46 +00:00
Jason Ansel
9aa3fedb75 Slightly faster FX graph iterator (#121611)
Before:
```
iterating over 100000000 FX nodes took 5.9s (16830686 nodes/s)
```

After:
```
iterating over 100000000 FX nodes took 5.0s (19937698 nodes/s)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121611
Approved by: https://github.com/oulgen
2024-03-11 20:00:19 +00:00
Oguz Ulgen
660ec3d38d [Export] Fix bug removing node from wrong graph (#121574)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121574
Approved by: https://github.com/ydwu4
2024-03-10 04:46:11 +00:00
Jiong Gong
896cf9d1ce [inductor][cpp] vectorization support for int32/int64 (#119001)
This pull request aims to complete most of the support for vectorizing int32 and int64 data types except for indirect indexing and masks. The basic data type support for uint32 and uint64 is also added but without vectorization. More vectorized conversion functions are added between integer and float. In order to support int64 vectors, a new VectorizedN class to handle vectors of arbitrary length. Below are the details:
1. Complete most of the int32 and int64 vectorization support including load, store, reduction, constant and conversion. The indirect indexing and masks will be addressed in follow-up PRs, after which, the legality checking logic in `CppVecKernelChecker` can be further simplified.
2. Util functions for conversion between integer and float vectors (in cpp_prefix.h and ATen vec). Ideally, we'd better move them from cpp_prefix.h to ATen vec to simplify cpp_prefix.h, will be addressed in follow-up PRs.
3. Introduced a new template class VectorizedN, designed to handle vectors of arbitrary length by encapsulating multiple Vectorized<T> instances. This class supports most of the operations of `Vectorized<T>`. It makes the support of int64 vectorization simpler. I will also apply it to bf16/fp16/int8 in the follow-up PRs for better efficiency. For example, bf16 currently only uses half of the vector lanes. With `VectorizedN`, we can use full of the lanes and map bf16 vector to `VectorizedN<float,2>` on conversion.
4. Basic data type support is added for uint32 and uint64 (in graph.py). Vectorization support will be added later but not of high priority due to fewer usages.

Next steps:

- [ ] Refactor the vector mask handling to support data types other than float. Currently vector masks are implemented with float vectors.
- [ ] Fully utilize vector lanes for bfloat16/float16/int8.
- [ ] Support indirect indexing with vectorized index via scalarization.
- [ ] Clean up `CppVecKernelChecker`.
- [ ] Simplify `cpp_prefix.h` including refactoring vector conversion logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119001
Approved by: https://github.com/peterbell10, https://github.com/jansel
2024-02-08 17:38:49 +00:00
Jeff Daily
01abb5af21 additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)
Follow up to #107586.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115214
Approved by: https://github.com/peterbell10, https://github.com/malfet
2024-01-22 18:33:41 +00:00
PyTorch MergeBot
b637fdc8b3 Revert "additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)"
This reverts commit 74e1362499.

Reverted https://github.com/pytorch/pytorch/pull/115214 on behalf of https://github.com/PaliC due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/115214#issuecomment-1900815152))
2024-01-19 17:35:04 +00:00
Jeff Daily
74e1362499 additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)
Follow up to #107586.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115214
Approved by: https://github.com/peterbell10
2024-01-19 00:50:18 +00:00
Edward Z. Yang
61a181e83c Report function name in stack trace annotations (#117459)
When working with internal flows, it can sometimes be ambiguous what
version of the code they are working with.  In this case, having the
function name available in the stack trace can help identify what you
are looking at.

Example now looks like:

```
[DEBUG]         # File: /data/users/ezyang/a/pytorch/a.py:5 in f, code: return x + x
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117459
Approved by: https://github.com/Skylion007
2024-01-15 00:29:13 +00:00
Aaron Gokaslan
3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parenthesis i.e.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):`->`if x > y or y < z:`
  - And `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
Xuehai Pan
199e07f108 [pytree][BE] update treespec num_children access (#116370)
Change `len(treespec.children_spes) -> treespec.num_children`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116370
Approved by: https://github.com/Skylion007
2023-12-24 20:54:32 +00:00
PyTorch MergeBot
684ce1b21d Revert "Assert that output could only be the last node of the FX graph (#115179)"
This reverts commit 4a9fb9832a.

Reverted https://github.com/pytorch/pytorch/pull/115179 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/115179#issuecomment-1845776365))
2023-12-07 17:26:27 +00:00
Oguz Ulgen
4a9fb9832a Assert that output could only be the last node of the FX graph (#115179)
Test Plan: unit tests

Differential Revision: D51856848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115179
Approved by: https://github.com/Chillee
2023-12-06 08:17:16 +00:00
PyTorch MergeBot
8bb3cd192f Revert "Assert that output could only be the last node of the FX graph (#114973)"
This reverts commit a85df9eb0b.

Reverted https://github.com/pytorch/pytorch/pull/114973 on behalf of https://github.com/atalman due to Diff broke internal tests ([comment](https://github.com/pytorch/pytorch/pull/114973#issuecomment-1839290400))
2023-12-04 19:07:48 +00:00
Oguz Ulgen
a85df9eb0b Assert that output could only be the last node of the FX graph (#114973)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114973
Approved by: https://github.com/Chillee
2023-12-01 23:04:19 +00:00
PyTorch MergeBot
7c8d3639cf Revert "[fx] log the node when it's get eliminated (#112684)"
This reverts commit 6256d3710e.

Reverted https://github.com/pytorch/pytorch/pull/112684 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/112684#issuecomment-1831198778))
2023-11-29 04:31:15 +00:00
Shiyan Deng
6256d3710e [fx] log the node when it's get eliminated (#112684)
Summary: ATT

Test Plan: CI

Reviewed By: strisunshinewentingwang

Differential Revision: D50912413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112684
Approved by: https://github.com/zyan0
2023-11-29 01:43:04 +00:00
Kaichao You
958f755a0e [FX][CodeGen] Make sure fx code is valid in python (#113345)
This PR fixes two cases when fx generated code is invalid in python (syntax error):

1. multiple type annotation in one line: `var1: annotation1, var2: annotation2 = function_call()`
2. invalid type annotation for scalars like `var1: f32[] = function_call()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113345
Approved by: https://github.com/ezyang
2023-11-10 21:12:16 +00:00
Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
Wenting Wang
675df7520a [tgif][multiforward] allow codegen to generate different func name (#111446)
Summary: see Shiyan's design doc for ATM TS publish weights dedupe https://fb.quip.com/HnUVAjUMaXMQ

Test Plan: tested in N4454041 after D50341352 that multiforward method is working for ts model

Differential Revision: D45750812

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111446
Approved by: https://github.com/842974287
2023-10-19 21:19:30 +00:00
willfengg
772e104dfd [inductor] visualize fused ops in svg graph (#107752)
example usage
* `TORCH_COMPILE_DEBUG=1 INDUCTOR_ORIG_FX_SVG=1 INDUCTOR_POST_FUSION_SVG=1 python trig.py`: show original fx node name, file, and code. see snapshot 2 where we have origin_0, 1, 2
* trig.py can be found in P816304818

Implementation
* keep original fx graph in GraphLowering, ```self.orig_gm: torch.fx.GraphModule = gm.__copy__()```
* draw original fx graph with origins ir_post_fusion ```V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)```. node.meta["buff_meta"] tracks buf_name

<img width="350" alt="Screenshot 2023-08-29 at 12 40 24 PM" src="https://github.com/pytorch/pytorch/assets/134637289/c4e197cb-ab3b-4a09-a584-c1356376accb">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107752
Approved by: https://github.com/mlazos
2023-09-21 08:03:05 +00:00
Yukio Siraichi
6e3a7473cf Trace calls with Python Enum values. (#109507)
Fix: #82135
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109507
Approved by: https://github.com/ezyang
2023-09-20 22:18:11 +00:00
William Wen
b904432e82 [dynamo] preserve some FX node metadata of GraphModules (#107067)
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.

This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.

The added unit test demonstrates the added functionality of this PR.

~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~

~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~

Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
2023-09-15 23:29:14 +00:00
PyTorch MergeBot
c5e7588613 Revert "[dynamo] preserve some FX node metadata of GraphModules (#107067)"
This reverts commit 1d42148fee.

Reverted https://github.com/pytorch/pytorch/pull/107067 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/107067#issuecomment-1717321061))
2023-09-13 09:59:33 +00:00
Michael Voznesensky
de0b18fad9 Use user directed names for variables where possible (#109092)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109092
Approved by: https://github.com/ezyang
ghstack dependencies: #108846
2023-09-13 07:44:04 +00:00
William Wen
1d42148fee [dynamo] preserve some FX node metadata of GraphModules (#107067)
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.

This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.

The added unit test demonstrates the added functionality of this PR.

~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~

~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
2023-09-11 17:11:51 +00:00
vasiliy
61fe49b8ed pt2: make aot_eager backend handle basic float8 operations (#107783)
Summary:

Reland of https://github.com/pytorch/pytorch/pull/107642 with a fix for tests on Windows.

Makes aot_eager backend of torch.compile handle basic float8 operations.

This is useful for float8 training UX.

Test Plan:

```
python test/test_quantization.py -k test_pt2_traceable_aot_eager
```

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107783
Approved by: https://github.com/albanD
2023-08-23 18:10:53 +00:00
PyTorch MergeBot
5025fb9213 Revert "pt2: make aot_eager backend handle basic float8 operations (#107642)"
This reverts commit 24147a8e1c.

Reverted https://github.com/pytorch/pytorch/pull/107642 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing Windows CPU test in trunk. The Windows failures on your PR looks related I think ([comment](https://github.com/pytorch/pytorch/pull/107642#issuecomment-1688999380))
2023-08-22 22:17:36 +00:00
vasiliy
24147a8e1c pt2: make aot_eager backend handle basic float8 operations (#107642)
Summary:

Makes aot_eager backend of torch.compile handle basic float8 operations.

This is useful for float8 training UX.

Test Plan:

```
python test/test_quantization.py -k test_pt2_traceable_aot_eager
```

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107642
Approved by: https://github.com/albanD
2023-08-22 18:57:14 +00:00
Tugsbayasgalan Manlaibaatar
4c46ea583f [Export] Support re-exportability (#106531)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106531
Approved by: https://github.com/zhxchen17
2023-08-03 18:27:26 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
5666d20bb8 Add unlifting pass under private config (#104897)
Summary: We wanna do this little by little. For now, I tried only on DissectedPartsModel which needs to use aot_export version.

Test Plan: CI

Reviewed By: zhxchen17

Differential Revision: D46785735

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104897
Approved by: https://github.com/JacobSzwejbka
2023-07-19 01:16:35 +00:00
Edward Z. Yang
666aeaa313 Preserve original co_filename when FX symbolic_trace (#103885)
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`

I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
2023-07-05 22:00:05 +00:00
Kunal Vaishnavi
709c9b5c93 Fix tabulate import error (#104468)
### Description

This PR fixes issue #104166 by re-raising the exception.

### Context

The `tabulate` package needs to be installed with `pip install tabulate` before calling `tabulate(...)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104468
Approved by: https://github.com/Skylion007, https://github.com/BowenBao
2023-07-03 21:55:53 +00:00