Commit Graph

157 Commits

Author SHA1 Message Date
Jeff Daily
01abb5af21 additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)
Follow up to #107586.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115214
Approved by: https://github.com/peterbell10, https://github.com/malfet
2024-01-22 18:33:41 +00:00
PyTorch MergeBot
b637fdc8b3 Revert "additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)"
This reverts commit 74e1362499.

Reverted https://github.com/pytorch/pytorch/pull/115214 on behalf of https://github.com/PaliC due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/115214#issuecomment-1900815152))
2024-01-19 17:35:04 +00:00
Jeff Daily
74e1362499 additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)
Follow up to #107586.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115214
Approved by: https://github.com/peterbell10
2024-01-19 00:50:18 +00:00
Edward Z. Yang
61a181e83c Report function name in stack trace annotations (#117459)
When working with internal flows, it can sometimes be ambiguous which
version of the code you are working with. In this case, having the
function name available in the stack trace can help identify what you
are looking at.

Example now looks like:

```
[DEBUG]         # File: /data/users/ezyang/a/pytorch/a.py:5 in f, code: return x + x
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117459
Approved by: https://github.com/Skylion007
2024-01-15 00:29:13 +00:00
Aaron Gokaslan
3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parentheses, e.g.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):`->`if x > y or y < z:`
  - And `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
Xuehai Pan
199e07f108 [pytree][BE] update treespec num_children access (#116370)
Change `len(treespec.children_specs)` to `treespec.num_children`.
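
A small sketch of the new accessor, with sample data of my own (not from the PR):

```python
import torch.utils._pytree as pytree

_, spec = pytree.tree_flatten({"a": 1, "b": (2, 3)})

# Before: len(spec.children_specs); after: the dedicated property.
assert spec.num_children == 2
```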

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116370
Approved by: https://github.com/Skylion007
2023-12-24 20:54:32 +00:00
PyTorch MergeBot
684ce1b21d Revert "Assert that output could only be the last node of the FX graph (#115179)"
This reverts commit 4a9fb9832a.

Reverted https://github.com/pytorch/pytorch/pull/115179 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/115179#issuecomment-1845776365))
2023-12-07 17:26:27 +00:00
Oguz Ulgen
4a9fb9832a Assert that output could only be the last node of the FX graph (#115179)
Test Plan: unit tests

Differential Revision: D51856848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115179
Approved by: https://github.com/Chillee
2023-12-06 08:17:16 +00:00
PyTorch MergeBot
8bb3cd192f Revert "Assert that output could only be the last node of the FX graph (#114973)"
This reverts commit a85df9eb0b.

Reverted https://github.com/pytorch/pytorch/pull/114973 on behalf of https://github.com/atalman due to Diff broke internal tests ([comment](https://github.com/pytorch/pytorch/pull/114973#issuecomment-1839290400))
2023-12-04 19:07:48 +00:00
Oguz Ulgen
a85df9eb0b Assert that output could only be the last node of the FX graph (#114973)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114973
Approved by: https://github.com/Chillee
2023-12-01 23:04:19 +00:00
PyTorch MergeBot
7c8d3639cf Revert "[fx] log the node when it gets eliminated (#112684)"
This reverts commit 6256d3710e.

Reverted https://github.com/pytorch/pytorch/pull/112684 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/112684#issuecomment-1831198778))
2023-11-29 04:31:15 +00:00
Shiyan Deng
6256d3710e [fx] log the node when it gets eliminated (#112684)
Summary: ATT

Test Plan: CI

Reviewed By: strisunshinewentingwang

Differential Revision: D50912413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112684
Approved by: https://github.com/zyan0
2023-11-29 01:43:04 +00:00
Kaichao You
958f755a0e [FX][CodeGen] Make sure fx code is valid in python (#113345)
This PR fixes two cases where FX-generated code is invalid Python (syntax errors):

1. multiple type annotations on one line: `var1: annotation1, var2: annotation2 = function_call()`
2. invalid type annotations for scalars, e.g. `var1: f32[] = function_call()`.
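
For illustration, a standalone sketch of the two invalid forms (hypothetical variable and function names, not the PR's actual generated code):

```python
def function_call():
    return 1.0, 2.0

# Case 1: Python allows only one annotated target per assignment.
#   var1: float, var2: float = function_call()   # SyntaxError
# Case 2: an empty subscript such as 'f32[]' cannot be parsed.
#   var1: f32[] = function_call()                # SyntaxError
# The generated code must drop the annotations on such lines instead:
var1, var2 = function_call()
```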

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113345
Approved by: https://github.com/ezyang
2023-11-10 21:12:16 +00:00
Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten(...)` to use `tree_leaves`.
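
A minimal sketch of the rewrite, with sample data of my own:

```python
import torch.utils._pytree as pytree

data = {"a": [1, 2], "b": (3, 4)}

# Before: flatten and discard the treespec.
leaves, _ = pytree.tree_flatten(data)

# After: tree_leaves returns just the leaves directly.
assert pytree.tree_leaves(data) == leaves  # [1, 2, 3, 4]
```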

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
Wenting Wang
675df7520a [tgif][multiforward] allow codegen to generate different func name (#111446)
Summary: see Shiyan's design doc for ATM TS publish weights dedupe https://fb.quip.com/HnUVAjUMaXMQ

Test Plan: tested in N4454041 (after D50341352) that the multiforward method works for the TS model

Differential Revision: D45750812

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111446
Approved by: https://github.com/842974287
2023-10-19 21:19:30 +00:00
willfengg
772e104dfd [inductor] visualize fused ops in svg graph (#107752)
Example usage:
* `TORCH_COMPILE_DEBUG=1 INDUCTOR_ORIG_FX_SVG=1 INDUCTOR_POST_FUSION_SVG=1 python trig.py`: shows the original fx node name, file, and code. See snapshot 2, where we have origin_0, 1, 2.
* trig.py can be found in P816304818

Implementation:
* keep the original fx graph in GraphLowering: `self.orig_gm: torch.fx.GraphModule = gm.__copy__()`
* draw the original fx graph with origins in `ir_post_fusion`: `V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)`; `node.meta["buff_meta"]` tracks the buffer name

[Screenshot 2023-08-29: SVG graph of fused ops with origin annotations]

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107752
Approved by: https://github.com/mlazos
2023-09-21 08:03:05 +00:00
Yukio Siraichi
6e3a7473cf Trace calls with Python Enum values. (#109507)
Fix: #82135
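
A minimal sketch of what this enables: compiling a call whose argument is a Python Enum without a graph break (toy function and enum of my own, not from the PR):

```python
import enum

import torch

class Color(enum.Enum):
    RED = 0
    GREEN = 1

def f(x, color):
    if color == Color.RED:
        return x + 1
    return x - 1

# fullgraph=True asserts that tracing succeeds without graph breaks.
compiled = torch.compile(f, fullgraph=True)
out = compiled(torch.ones(2), Color.RED)
```
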
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109507
Approved by: https://github.com/ezyang
2023-09-20 22:18:11 +00:00
William Wen
b904432e82 [dynamo] preserve some FX node metadata of GraphModules (#107067)
Requested by @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case where we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.

This feature also has the additional benefit of preserving the originating line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.

The added unit test demonstrates the added functionality of this PR.

~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds to by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~

~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~

Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
2023-09-15 23:29:14 +00:00
PyTorch MergeBot
c5e7588613 Revert "[dynamo] preserve some FX node metadata of GraphModules (#107067)"
This reverts commit 1d42148fee.

Reverted https://github.com/pytorch/pytorch/pull/107067 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/107067#issuecomment-1717321061))
2023-09-13 09:59:33 +00:00
Michael Voznesensky
de0b18fad9 Use user directed names for variables where possible (#109092)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109092
Approved by: https://github.com/ezyang
ghstack dependencies: #108846
2023-09-13 07:44:04 +00:00
William Wen
1d42148fee [dynamo] preserve some FX node metadata of GraphModules (#107067)
Requested by @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case where we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.

This feature also has the additional benefit of preserving the originating line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.

The added unit test demonstrates the added functionality of this PR.

~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds to by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~

~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
2023-09-11 17:11:51 +00:00
vasiliy
61fe49b8ed pt2: make aot_eager backend handle basic float8 operations (#107783)
Summary:

Reland of https://github.com/pytorch/pytorch/pull/107642 with a fix for tests on Windows.

Makes aot_eager backend of torch.compile handle basic float8 operations.

This is useful for float8 training UX.

Test Plan:

```
python test/test_quantization.py -k test_pt2_traceable_aot_eager
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107783
Approved by: https://github.com/albanD
2023-08-23 18:10:53 +00:00
PyTorch MergeBot
5025fb9213 Revert "pt2: make aot_eager backend handle basic float8 operations (#107642)"
This reverts commit 24147a8e1c.

Reverted https://github.com/pytorch/pytorch/pull/107642 on behalf of https://github.com/huydhn due to Sorry for reverting this, but it is failing Windows CPU test in trunk. The Windows failures on your PR looks related I think ([comment](https://github.com/pytorch/pytorch/pull/107642#issuecomment-1688999380))
2023-08-22 22:17:36 +00:00
vasiliy
24147a8e1c pt2: make aot_eager backend handle basic float8 operations (#107642)
Summary:

Makes aot_eager backend of torch.compile handle basic float8 operations.

This is useful for float8 training UX.

Test Plan:

```
python test/test_quantization.py -k test_pt2_traceable_aot_eager
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107642
Approved by: https://github.com/albanD
2023-08-22 18:57:14 +00:00
Tugsbayasgalan Manlaibaatar
4c46ea583f [Export] Support re-exportability (#106531)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106531
Approved by: https://github.com/zhxchen17
2023-08-03 18:27:26 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
5666d20bb8 Add unlifting pass under private config (#104897)
Summary: We want to do this little by little. For now, I tried it only on DissectedPartsModel, which needs to use the aot_export version.

Test Plan: CI

Reviewed By: zhxchen17

Differential Revision: D46785735

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104897
Approved by: https://github.com/JacobSzwejbka
2023-07-19 01:16:35 +00:00
Edward Z. Yang
666aeaa313 Preserve original co_filename when FX symbolic_trace (#103885)
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`

I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.
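
A rough sketch of where the filename is visible (toy module of my own; the `from ... in forward` suffix appears only when `co_fields` is populated, as dynamo does):

```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        return x.relu()

gm = fx.symbolic_trace(M())
# The generated forward lives in a synthetic code object; its filename
# used to be just '<eval_with_key>.0', and with co_fields plumbed through
# it can also carry the original file, line, and function name.
print(gm.forward.__code__.co_filename)
```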

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
2023-07-05 22:00:05 +00:00
Kunal Vaishnavi
709c9b5c93 Fix tabulate import error (#104468)
### Description

This PR fixes issue #104166 by re-raising the exception.

### Context

The `tabulate` package needs to be installed with `pip install tabulate` before calling `tabulate(...)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104468
Approved by: https://github.com/Skylion007, https://github.com/BowenBao
2023-07-03 21:55:53 +00:00
PyTorch MergeBot
29e3fddb08 Revert "Preserve original co_filename when FX symbolic_trace (#103885)"
This reverts commit b9f81a483a.

Reverted https://github.com/pytorch/pytorch/pull/103885 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/103885#issuecomment-1603612781))
2023-06-23 02:49:04 +00:00
Edward Z. Yang
b9f81a483a Preserve original co_filename when FX symbolic_trace (#103885)
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`

I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
2023-06-21 08:28:50 +00:00
Michael Suo
a475ea4542 [fx] change from #users to num_users in graph printout (#101140)
`#users` means something in various chat apps, which makes it annoying to copy-paste graphs into them.
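
A quick way to see the printout (toy function of my own; exact output varies by version):

```python
import torch
import torch.fx as fx

def f(x):
    y = x + 1
    return y * y

# Node lines now read e.g. '%add : [num_users=1] = ...' rather than the
# old '#users=1' spelling, so they paste cleanly into chat apps.
print(fx.symbolic_trace(f).graph)
```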

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101140
Approved by: https://github.com/ezyang
2023-06-20 21:24:32 +00:00
Animesh Jain
168ae806d0 [fx] Fix repr when arg is an OpOverload (#102547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102547
Approved by: https://github.com/Skylion007, https://github.com/jansel
2023-05-30 21:11:05 +00:00
PyTorch MergeBot
66eef31444 Revert "[fx] change from #users to num_users in graph printout (#101140)"
This reverts commit e568c5a18d.

Reverted https://github.com/pytorch/pytorch/pull/101140 on behalf of https://github.com/jeanschmidt due to There are internal changes to this commit that are preventing landing, so I am reverting to unblock the diff train ([comment](https://github.com/pytorch/pytorch/pull/101140#issuecomment-1547989487))
2023-05-15 14:35:22 +00:00
Michael Suo
e568c5a18d [fx] change from #users to num_users in graph printout (#101140)
`#users` means something in various chat apps, which makes it annoying to copy-paste graphs into them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101140
Approved by: https://github.com/ezyang
2023-05-12 04:34:01 +00:00
Jason Ansel
176ef97fc1 [inductor] Fix bug where a node gets erased twice (#100848)
Fixes #100806

The underlying bug is that if you erase an FX node twice, everything runs without error, but `len(graph.nodes)` reports an incorrect value.
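
A sketch of the problematic pattern (toy graph of my own, not the PR's repro):

```python
import torch
import torch.fx as fx

def f(x):
    return x + 1

g = fx.symbolic_trace(f).graph
node = next(n for n in g.nodes if n.op == "call_function")
node.replace_all_uses_with(node.args[0])  # detach users so erasing is legal
g.erase_node(node)
# Before the fix, a second g.erase_node(node) on the already-erased node
# also ran without error but left len(g.nodes) reporting a stale count.
```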

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100848
Approved by: https://github.com/ngimel, https://github.com/Skylion007
2023-05-08 18:24:19 +00:00
Kazuaki Ishizaki
105ef68f72 Fix typos under torch/fx directory (#97596)
This PR fixes typos in comments and messages of `.py` files under the `torch/fx` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97596
Approved by: https://github.com/dagitses, https://github.com/kit1980
2023-04-10 21:57:36 +00:00
Sherlock Huang
f8692dcc4a Node.stack_trace should have innermost frame last (#95592)
Both fx.Tracer and Dynamo should store node.stack_trace in the "innermost frame last" order.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95592
Approved by: https://github.com/ezyang
2023-02-28 02:14:40 +00:00
Angela Yi
d301caa890 Deepcopy output node metadata (#95426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95426
Approved by: https://github.com/SherlockNoMad
2023-02-27 15:25:54 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes explicit inheritance from `object` and removes unused `__future__` imports.
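
For illustration, the two mechanical rewrites involved (hypothetical class name):

```python
# Before: Python 2 relics.
#   from __future__ import print_function  # unused future import, removed
#   class Tracer(object): ...              # explicit object base, removed

# After: plain Python 3.
class Tracer:
    pass
```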

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
min-jean-cho
900e09c872 [Dynamo] Support torch.Tensor.fn as TorchVariable, not UserDefinedObjectVariable, preventing graph break (#93243)
As found in #92709, thanks to @ngimel and @jansel, `torch.Tensor.fn` currently maps to `UserDefinedObjectVariable` rather than `TorchVariable`. The root cause is explained in https://github.com/pytorch/pytorch/pull/92709#pullrequestreview-1273357406. To prevent this, build a `TorchVariable` for `torch.Tensor.fn` that points to `torch.ops.aten.fn`.

This issue causes a graph break on `torch.Tensor.fn` with `nopython=True`.
```python
import torch
import torch._dynamo as dynamo

#op = torch.ops.aten.abs_ # no graph break
op = torch.Tensor.abs_ # graph break
args = torch.empty(10)

def foo(args):
    return op(args)

opt_foo = dynamo.optimize("inductor", nopython=True)(foo)
y_ = opt_foo(args)

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93243
Approved by: https://github.com/jansel
2023-02-07 09:26:50 +00:00
Thiago Crepaldi
95dfad9d93 Add kwargs support to torch.export() API (#92013)
Fixes [#1997](https://github.com/pytorch/torchdynamo/issues/1997)
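
A hypothetical sketch, assuming the `torch._dynamo.export` entry point of this era forwards extra keyword arguments to the traced function (toy function of mine):

```python
import torch
import torch._dynamo as dynamo

def f(x, *, scale=1.0):
    return x * scale

# Keyword arguments can now ride along with the example inputs.
gm, guards = dynamo.export(f, torch.ones(3), scale=2.0)
```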

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92013
Approved by: https://github.com/jansel
2023-01-27 01:58:51 +00:00
PyTorch MergeBot
1a98c3e36c Revert "Add kwargs support to torch.export() API (#92013)"
This reverts commit 890b68281a.

Reverted https://github.com/pytorch/pytorch/pull/92013 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-01-16 13:03:48 +00:00
Thiago Crepaldi
890b68281a Add kwargs support to torch.export() API (#92013)
Fixes [#1997](https://github.com/pytorch/torchdynamo/issues/1997)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92013
Approved by: https://github.com/jansel
2023-01-13 15:17:26 +00:00
Sherlock Huang
eeba9d5ab4 Preserve node's meta during fx.transformation (#90737)
We wish to preserve node.meta across fx.Transformer transformations and aot_autograd. This will preserve all the meta fields of the original node, including stack_trace, nn_module_stack, val, tensor_meta...

Sample

Here's a graph produced by Dynamo.
```
class GraphModule(torch.nn.Module):
    def forward(self, x : torch.Tensor, y : torch.Tensor):
        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:35, code: a = torch.cos(x)
        cos = torch.cos(x);  x = None

        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:36, code: b = torch.sin(y)
        sin = torch.sin(y);  y = None

        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:37, code: return a + b
        add = cos + sin;  cos = sin = None
        return (add,)

x {'creation_timestamp': 0, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 45, in forward\n    def forward(self, x, y):\n'}
y {'creation_timestamp': 0, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 45, in forward\n    def forward(self, x, y):\n'}
cos {'creation_timestamp': 3, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 35, in forward\n    a = torch.cos(x)\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n'}
sin {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 36, in forward\n    b = torch.sin(y)\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n'}
add {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 37, in forward\n    return a + b\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n'}
output {'creation_timestamp': 4}
```

After lowering to an aten graph with aot_autograd_simplified():
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[2, 3], primals_2: f32[2, 3]):
        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:35, code: a = torch.cos(x)
        cos: f32[2, 3] = torch.ops.aten.cos.default(primals_1)

        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:36, code: b = torch.sin(y)
        sin: f32[2, 3] = torch.ops.aten.sin.default(primals_2)

        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:37, code: return a + b
        add: f32[2, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
        return [add, primals_2, primals_1]

primals_1 {'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=True, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
primals_2 {'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=True, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
cos {'creation_timestamp': 3, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 35, in forward\n    a = torch.cos(x)\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n', 'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
sin {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 36, in forward\n    b = torch.sin(y)\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n', 'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
add {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 37, in forward\n    return a + b\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n', 'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
output {}
```

Notice that the output fx nodes have creation_timestamp, nn_module_stack, and stack_trace copied from the original fx nodes.
val and tensor_meta were later populated by a subsequent fake_tensor_propagation.
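
A minimal sketch of the preserved-metadata behavior (toy graph and meta value of my own):

```python
import torch
import torch.fx as fx

def f(x):
    return torch.cos(x)

gm = fx.symbolic_trace(f)
for n in gm.graph.nodes:
    n.meta["stack_trace"] = "original trace"

# With this change, node.meta survives the fx.Transformer round trip.
out = fx.Transformer(gm).transform()
for n in out.graph.nodes:
    print(n.op, n.meta.get("stack_trace"))
```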

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90737
Approved by: https://github.com/jerryzh168
2023-01-06 17:21:02 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
f660d62ddc Make dynamo.export preserve user input/output format (#90884)
Currently, dynamo flattens the user inputs, so when a user reuses the inputs they traced with, the exported graph wouldn't work because it expects flat args. This PR changes this behaviour by explicitly wrapping the dynamo-produced graph with the correct user input/output format.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90884
Approved by: https://github.com/zhxchen17, https://github.com/voznesenskym
2022-12-16 00:57:09 +00:00
Edward Z. Yang
d1ee073041 Handle case when candidate is empty (#88359)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88359
Approved by: https://github.com/wconstab
2022-11-05 17:19:40 +00:00
Edward Z. Yang
1ff52225f1 Unify SymIntNode and SymFloatNode into SymNode (#87817)
This refactor was prompted by challenges handling mixed int/float
operations in C++.  A previous version of this patch
added overloads for each permutation of int/float and was unwieldy
(https://github.com/pytorch/pytorch/pull/87722/). This PR takes a different
approach.

The general outline of the patch is to combine the C++ types SymIntNode
and SymFloatNode into a single type, SymNode.  This is type erased; we
no longer know statically at C++ if we have an int/float and have to test
it with the is_int()/is_float() virtual methods.  This has a number of
knock on effects.

- We no longer have C++ classes to bind to Python.  Instead, we take an
  entirely new approach to our Python API, where we have a SymInt/SymFloat
class defined entirely in Python, which holds a SymNode (which corresponds
  to the C++ SymNode).  However, SymNode is not pybind11-bound; instead,
  it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode
  when it goes into C++.  This implies a userland rename.

  In principle, it is also possible for the canonical implementation of SymNode
  to be written in C++, and then bound to Python with pybind11 (we have
  this code, although it is commented out.)  However, I did not implement
  this as we currently have no C++ implementations of SymNode.

  Because we do return SymInt/SymFloat from C++ bindings, the C++ binding
  code needs to know how to find these classes.  Currently, this is done
  just by manually importing torch and getting the attributes.

- Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now
  takes SymInt/SymFloat, rather than SymNode, bringing it in line with how
  __torch_dispatch__ works.

Some miscellaneous improvements:

- SymInt now has a constructor that takes SymNode.  Note that this
  constructor is ambiguous if you pass in a subclass of SymNode,
  so an explicit downcast is necessary.  This means toSymFloat/toSymInt
  are no more.  This is a mild optimization as it means rvalue reference
  works automatically.

- We uniformly use the caster for c10::SymInt/SymFloat, rather than
  going the long way via the SymIntNode/SymFloatNode.

- Removed some unnecessary toSymInt/toSymFloat calls in normalize_*
  functions, pretty sure this doesn't do anything.

- guard_int is now a free function, since to guard on an int you cannot
  assume the method exists.  A function can handle both int and SymInt
  inputs.

- We clean up the magic method definition code for SymInt/SymFloat/SymNode.
  ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets
  plain methods; this is to help avoid confusion between the two types.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
2022-10-27 20:56:02 +00:00
Horace He
fc3beef5ac Fix stupid N^2 naming behavior in FX and remove assert that slows things a lot sometimes (#87533)
cc @jansel @lezcano @fdrocha @mlazos @soumith @voznesenskym @yanboliang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87533
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
2022-10-23 08:26:37 +00:00
Yidi Wu
4dc579838b Allow fx.Graph.owning_module to be used as attribute. (#86822)
Summary:
The current behavior of the owning_module setter is difficult to understand: it changes owning_module to None if the owner count is not 0, yet still increments the owner count. If owning_module is None, the owner count should be 0, since no owner is accessible. Conversely, if the owner count can grow, owning_module should be a collection (e.g. a list).

This diff changes owning_module to be a normal attribute. The semantics are that a graph can have **at most one** owning module and can be assigned to a new module.

The alternative is to use a list to represent the owning_modules of a graph, but that breaks backward compatibility, and the exact use cases for multiple owning_modules are unclear.
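
A minimal sketch of the new semantics (toy function of my own):

```python
import torch
import torch.fx as fx

def f(x):
    return x * 2

gm = fx.symbolic_trace(f)
# owning_module is now a plain attribute: at most one owner, reassignable,
# rather than silently flipping to None while an owner count grew.
assert gm.graph.owning_module is gm
```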

Test Plan: Test with CI.

Differential Revision: D40200624

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86822
Approved by: https://github.com/tugsbayasgalan
2022-10-19 00:12:59 +00:00