Commit Graph

122 Commits

chilli
c486e2ab64 Add coloring to fx graph print out (#128476)
Note: won't land immediately; at the least I'll need to add a color option to the field. But curious whether any tests fail.

Old:
<img width="1294" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/c3a750ed-5e54-4621-b2e4-be5481be15b6">

New:
<img width="1303" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/3a1f1adc-6f3a-413e-8b87-ee53da9bf4ed">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128476
Approved by: https://github.com/ezyang
2024-06-13 23:39:04 +00:00
Aaron Orenstein
038b927590 Flip default value for mypy disallow_untyped_defs [7/11] (#127844)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843
2024-06-08 18:49:45 +00:00
Sheng Fu
bbeb0906c4 Register create_node_hook (#126671)
Differential Revision: D57469157

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126671
Approved by: https://github.com/angelayi
2024-05-24 23:32:15 +00:00
Matthew Hoffman
81277baa0c Remove removed ruff rule TRY200 (#126256)
My TOML linter is complaining that "TRY200" is not acceptable for the `tool.ruff.lint` schema.

From the ruff docs: https://docs.astral.sh/ruff/rules/reraise-no-cause/

> This rule has been removed and its documentation is only available for historical reasons.
>
> This rule is identical to [B904](https://docs.astral.sh/ruff/rules/raise-without-from-inside-except/) which should be used instead.

and we are currently explicitly ignoring B904.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126256
Approved by: https://github.com/Skylion007
2024-05-17 16:31:05 +00:00
angelayi
8be4c1bc2f [export] Add metadata for nodes insert_deferred_runtime_asserts (#125414)
Fixes [internal error](https://fb.workplace.com/groups/1075192433118967/permalink/1416709435633930/).

The issue is that the asserting nodes added in the `insert_deferred_runtime_assertion` pass do not contain metadata that the ExportedProgram requires the graph to have. One solution is to retrace the entire module; another is to manually add back this metadata.

This diff implements the latter solution (manually adding back the metadata) by hooking into fx.graph's `create_node` function and adding export-specific metadata for every node that is created. I did it this way so that the `insert_deferred_runtime_assertion` pass does not have to know what metadata export wants.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125414
Approved by: https://github.com/zhxchen17, https://github.com/BoyuanFeng
2024-05-07 23:15:21 +00:00
Aaron Gokaslan
29cc293725 [BE]: FURB142 - Remove set mutations. Use set update (#124551)
Uses set mutation methods (`update`, set difference, etc.) instead of manually reimplementing them.
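
For illustration, a minimal before/after sketch of the pattern this lint targets (the names here are made up):

```python
items = ["b", "c", "d"]
seen = {"a"}

# Before: manually re-implementing a set mutation
for item in items:
    seen.add(item)

# After: use the dedicated mutation method
seen.update(items)
```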

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124551
Approved by: https://github.com/ezyang
2024-04-21 14:12:33 +00:00
Sherlock Huang
c2f687f32c Option to include stride and device annotation in gm.print_readable() (#123690)
Summary:
Sample output for gm.print_readable(include_stride=True, include_device=True)

```
        getitem_21: "i32[1200][1]cuda:0" = auto_functionalized_4[1]
        copy_2: "f32[2, 60][60, 1]cuda:1"  = ....
```

Test Plan: CI

Differential Revision: D55949129

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123690
Approved by: https://github.com/Chillee
2024-04-11 06:53:10 +00:00
Simon Fan
f178d996a8 [dynamo] Fix traceback generation on runtime errors (#122746)
Fixes `During handling of the above exception, another exception occurred: [...] torch._dynamo.exc.Unsupported: generator`. `traceback.format_exc` uses generators, which aren't supported by dynamo yet.
Current error message:

```
======================================================================
ERROR: test_custom_fn_saved_tensors (__main__.TestCompiledAutograd)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/xmfan/core/pytorch/torch/fx/graph_module.py", line 307, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/xmfan/core/pytorch/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xmfan/core/pytorch/torch/nn/modules/module.py", line 1537, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.0", line 4, in forward
    def forward(self, inputs, sizes, hooks):
IndexError: list index out of range

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/xmfan/core/pytorch/torch/testing/_internal/common_utils.py", line 2741, in wrapper
    method(*args, **kwargs)
  File "/home/xmfan/core/pytorch/test/inductor/test_compiled_autograd.py", line 499, in test_custom_fn_saved_tensors
    self.check_output_and_recompiles(fn, 1)
  File "/home/xmfan/core/pytorch/test/inductor/test_compiled_autograd.py", line 61, in check_output_and_recompiles
    actual = list(opt_fn())
  File "/home/xmfan/core/pytorch/test/inductor/test_compiled_autograd.py", line 495, in fn
    loss.backward()
  File "/home/xmfan/core/pytorch/torch/_tensor.py", line 534, in backward
    torch.autograd.backward(
  File "/home/xmfan/core/pytorch/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/home/xmfan/core/pytorch/torch/autograd/graph.py", line 766, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/xmfan/core/pytorch/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xmfan/core/pytorch/torch/nn/modules/module.py", line 1537, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xmfan/core/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
    res = fn(*args, **kwargs)
  File "/home/xmfan/core/pytorch/torch/fx/graph_module.py", line 741, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/home/xmfan/core/pytorch/torch/fx/graph_module.py", line 315, in __call__
    _WrappedCall._generate_error_message(topmost_framesummary),
  File "/home/xmfan/core/pytorch/torch/fx/graph_module.py", line 289, in _generate_error_message
    tb_repr = get_traceback()
  File "/home/xmfan/core/pytorch/torch/fx/graph_module.py", line 288, in get_traceback
    return traceback.format_exc()
  File "/home/xmfan/.conda/envs/benchmarks/lib/python3.10/traceback.py", line 183, in format_exc
    return "".join(format_exception(*sys.exc_info(), limit=limit, chain=chain))
  File "/home/xmfan/.conda/envs/benchmarks/lib/python3.10/traceback.py", line 136, in format_exception
    return list(te.format(chain=chain))
  File "/home/xmfan/core/pytorch/torch/_dynamo/convert_frame.py", line 941, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/xmfan/core/pytorch/torch/_dynamo/convert_frame.py", line 348, in _convert_frame_assert
    unimplemented("generator")
  File "/home/xmfan/core/pytorch/torch/_dynamo/exc.py", line 199, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: generator
```


With this change, we get back the descriptive error message:
```
Traceback (most recent call last):
  File "/home/xmfan/core/pytorch/torch/fx/graph_module.py", line 307, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/home/xmfan/core/pytorch/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xmfan/core/pytorch/torch/nn/modules/module.py", line 1537, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.0", line 4, in forward
    def forward(self, inputs, sizes, hooks):
IndexError: list index out of range

Call using an FX-traced Module, line 4 of the traced Module's generated forward function:

def forward(self, inputs, sizes, hooks):

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    getitem = inputs[0]

    getitem_1 = inputs[1];  inputs = None
```


Pull Request resolved: https://github.com/pytorch/pytorch/pull/122746
Approved by: https://github.com/jansel, https://github.com/anijain2305
ghstack dependencies: #122691
2024-03-28 14:40:54 +00:00
Oguz Ulgen
7c5e29ae71 Back out "Support triton.language.dtype with torch.compile (#121690)" (#122108)
Summary: Some hard-to-deal-with package import/export related problems. Let's revert and start with a clean slate.

Test Plan: CI

Differential Revision: D55024877

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122108
Approved by: https://github.com/ezyang
2024-03-18 20:50:28 +00:00
Oguz Ulgen
e39aedfcc5 Fix fx graph triton import bug (#122041)
Summary: Unless we register triton as a special import, the FX graph import mechanism imports it as `from fx-generated._0 import triton as triton`, which is obviously broken.

Test Plan:
I could not figure out how to write a test for this but
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//tgif/lib/tests/gpu_tests:lowering_pass_test -- -r test_default_ait_lowering_multi_hardwares
```
now passes

Differential Revision: D54990782

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122041
Approved by: https://github.com/aakhundov
2024-03-17 22:48:51 +00:00
Shunting Zhang
fe10b1800f LazyGraphModule (#117911)
I feel it's easier to open a new PR rather than iterating on the previous PR (https://github.com/pytorch/pytorch/pull/105257 ) since this is more like a rewrite.

In this PR, instead of changing GraphModule directly, which can easily cause BC issues, I create a LazyGraphModule class, as Zachary & Jason suggested in comments on the previous PR.

The difference between LazyGraphModule and GraphModule is mainly in how recompilation of the graph module happens. In GraphModule, recompilation happens eagerly: constructing a GraphModule triggers it. In LazyGraphModule, we just mark the module as needing recompilation; the real recompilation only happens when absolutely required (e.g. calling the forward method, accessing the code property, etc.). In a lot of torch.compile cases the real recompilation is never triggered at all, which can save a few seconds of compilation time.

By default, GraphModule rather than LazyGraphModule is used. The `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead; this has been applied to the torch.compile stack.
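
A rough usage sketch of the opt-in described above. The import path below is an assumption (the commit only names a `use_lazy_graph_module(True)` context manager), so treat it as illustrative rather than the final API location:

```python
import torch
import torch.fx as fx

# Assumed import location for the context manager named in this commit message;
# the real path/name may differ.
from torch.fx._lazy_graph_module import _use_lazy_graph_module as use_lazy_graph_module


class M(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1


with use_lazy_graph_module(True):
    # Graph modules produced here should be LazyGraphModules: they are only
    # marked as needing recompilation instead of regenerating code eagerly.
    gm = fx.symbolic_trace(M())

# First real use (forward call / .code access) triggers the deferred recompile.
out = gm(torch.randn(4))
```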

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117911
Approved by: https://github.com/jansel
2024-01-27 04:10:18 +00:00
Zhengxu Chen
abd759d50d [fx] Add hooks to intercept node replacements. (#117825)
Summary: Adds an experimental API to the FX graph module to run "hooks" every time we change or replace nodes in a graph, so that we can properly update the new name in the graph signature and potentially other places.

Test Plan:
buck test mode/opt  -c fbcode.enable_gpu_sections=true caffe2/test/distributed/_tensor/experimental:tp_transform

buck test mode/opt caffe2/test:test_export -- -r test_replace_hook

Differential Revision: D52896531

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117825
Approved by: https://github.com/avikchaudhuri
2024-01-23 22:28:40 +00:00
Aaron Gokaslan
ee5d981249 [BE]: Enable RUFF PERF402 and apply fixes (#115505)
* Enable PERF402. Makes code more efficient and succinct by removing useless element-by-element list copies that can instead be done with a list constructor or an `extend` call. All test cases have noqa added since performance is not as sensitive in that folder.
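
A minimal before/after sketch of the kind of copy this rule removes (illustrative names only):

```python
source = [1, 2, 3, 4]

# Before: a useless element-by-element copy
result = []
for x in source:
    result.append(x)

# After: construct the list directly...
result = list(source)
# ...or extend an existing list in one call
combined = [0]
combined.extend(source)
```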

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115505
Approved by: https://github.com/malfet
2023-12-20 18:01:24 +00:00
Shiyan Deng
6e495eef60 [tgif] allow preserving non-forward methods during deepcopy (#114849)
Summary:
bypass-github-export-checks
force-merge-on-github

Reviewed By: sayitmemory

Differential Revision: D51629520

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114849
Approved by: https://github.com/houseroad
2023-12-01 21:51:05 +00:00
Shiyan Deng
fe7b845c8d [tgif] preserve non-forward method during torch package serialization (#114702)
Reviewed By: terrycsy, sayitmemory

Differential Revision: D51607058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114702
Approved by: https://github.com/houseroad
2023-11-29 22:31:35 +00:00
Aaron Gokaslan
cb856b08b2 [BE]: Attach cause to some exceptions and enable RUFF TRY200 (#111496)
Did some easy fixes from enabling TRY200. Most of these seem like oversights rather than intentional. The proper way to silence intentional cases is with `from None`, to note that you thought about whether the exception should carry its cause and decided against it.
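
A short sketch of the two patterns mentioned above (the function and exception types are hypothetical):

```python
def parse_port(raw: str) -> int:
    try:
        return int(raw)
    except ValueError as err:
        # Preferred: attach the original cause so the traceback shows both errors.
        raise RuntimeError(f"invalid port: {raw!r}") from err


def parse_port_opaque(raw: str) -> int:
    try:
        return int(raw)
    except ValueError:
        # Intentional suppression: `from None` records that the cause was
        # considered and deliberately hidden.
        raise RuntimeError(f"invalid port: {raw!r}") from None
```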

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111496
Approved by: https://github.com/malfet
2023-10-19 21:56:36 +00:00
Wenting Wang
675df7520a [tgif][multiforward] allow codegen to generate different func name (#111446)
Summary: see Shiyan's design doc for ATM TS publish weights dedupe https://fb.quip.com/HnUVAjUMaXMQ

Test Plan: tested in N4454041 after D50341352 that multiforward method is working for ts model

Differential Revision: D45750812

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111446
Approved by: https://github.com/842974287
2023-10-19 21:19:30 +00:00
Aaron Gokaslan
a0632389b7 [BE]: Update lintrunner mypy to 1.6.0 (#111375)
Follow up to #111305 that updates lintrunner's version too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111375
Approved by: https://github.com/malfet
2023-10-17 01:22:06 +00:00
Sam Larsen
0dfa354570 [inductor] Implement Fx graph caching to improve warm compilation time. (#103453)
Summary: Implement an on-disk cache to save and reuse compiled FX Graphs. This implementation does not handle tensors with symbolic shapes. This needs to be done in a follow-up PR.

Test Plan:
* New unit tests exercising saving and loading from the cache.
* New unit tests to exercise the cache key calculations.
* Ran several benchmarks to see cache hit and resulting compilation times.
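
A rough sketch of how the cache might be exercised; the `fx_graph_cache` config toggle below is an assumption for illustration and the exact knob may differ:

```python
import torch
import torch._inductor.config as inductor_config

# Assumption: this flag opts in to the on-disk FX graph cache described above.
inductor_config.fx_graph_cache = True


def f(x):
    return torch.sin(x) + torch.cos(x)


compiled = torch.compile(f)
compiled(torch.randn(8))  # cold run: compiles the FX graph and writes a cache entry

# A fresh process compiling the same graph would be expected to hit the on-disk
# cache and skip most of the warm compilation cost.
```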

Differential Revision: [D50255289](https://our.internmc.facebook.com/intern/diff/D50255289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103453
Approved by: https://github.com/eellison, https://github.com/Chillee
2023-10-13 13:33:56 +00:00
PyTorch MergeBot
7fbfa4e020 Revert "[inductor] Implement Fx graph caching to improve warm compilation time. (#103453)"
This reverts commit fc1105b282.

Reverted https://github.com/pytorch/pytorch/pull/103453 on behalf of https://github.com/kit1980 due to Same issue unfortunately, the newly added test fails on internal builds ([comment](https://github.com/pytorch/pytorch/pull/103453#issuecomment-1760202365))
2023-10-12 18:54:51 +00:00
Sam Larsen
fc1105b282 [inductor] Implement Fx graph caching to improve warm compilation time. (#103453)
Summary: Implement an on-disk cache to save and reuse compiled FX Graphs. This implementation does not handle tensors with symbolic shapes. This needs to be done in a follow-up PR.

Test Plan:
* New unit tests exercising saving and loading from the cache.
* New unit tests to exercise the cache key calculations.
* Ran several benchmarks to see cache hit and resulting compilation times.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103453
Approved by: https://github.com/eellison, https://github.com/Chillee
2023-10-11 14:39:14 +00:00
PyTorch MergeBot
3100d3e661 Revert "[inductor] Implement Fx graph caching to improve warm compilation time. (#103453)"
This reverts commit 8a8668e1ae.

Reverted https://github.com/pytorch/pytorch/pull/103453 on behalf of https://github.com/kit1980 due to The newly added test fails on internal builds ([comment](https://github.com/pytorch/pytorch/pull/103453#issuecomment-1756449919))
2023-10-10 23:21:59 +00:00
Sam Larsen
8a8668e1ae [inductor] Implement Fx graph caching to improve warm compilation time. (#103453)
Summary: Implement an on-disk cache to save and reuse compiled FX Graphs. This implementation does not handle tensors with symbolic shapes. This needs to be done in a follow-up PR.

Test Plan:
* New unit tests exercising saving and loading from the cache.
* New unit tests to exercise the cache key calculations.
* Ran several benchmarks to see cache hit and resulting compilation times.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103453
Approved by: https://github.com/eellison
2023-10-08 20:32:15 +00:00
William Wen
b904432e82 [dynamo] preserve some FX node metadata of GraphModules (#107067)
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.

This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.

The added unit test demonstrates the added functionality of this PR.
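
A small sketch of inspecting the node metadata in question; exactly which keys get populated depends on the tracing frontend, so treat the key names in the comment as examples rather than guarantees:

```python
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    # Metadata such as "nn_module_stack", "source_fn" and "stack_trace" lives
    # in node.meta when the frontend records it.
    print(node.op, node.target, sorted(node.meta.keys()))
```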

~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~

~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~

Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
2023-09-15 23:29:14 +00:00
PyTorch MergeBot
c5e7588613 Revert "[dynamo] preserve some FX node metadata of GraphModules (#107067)"
This reverts commit 1d42148fee.

Reverted https://github.com/pytorch/pytorch/pull/107067 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/107067#issuecomment-1717321061))
2023-09-13 09:59:33 +00:00
William Wen
1d42148fee [dynamo] preserve some FX node metadata of GraphModules (#107067)
Requested from @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case when we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and we want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.

This feature also has the additional benefit of being able to preserve origin line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.

The added unit test demonstrates the added functionality of this PR.

~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~

~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 have shortcomings. For node names, we only have access to new node names, not the old ones. Using line number is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
2023-09-11 17:11:51 +00:00
Jing Shan
fc2b980000 [Lint] Auto format graph_module.py (#108594)
Summary: Auto format the `graph_module.py` file

Test Plan: lint

Differential Revision: D48983066

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108594
Approved by: https://github.com/jiayisuse
2023-09-08 00:04:21 +00:00
Edward Z. Yang
666aeaa313 Preserve original co_filename when FX symbolic_trace (#103885)
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`

I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
2023-07-05 22:00:05 +00:00
Shiyan Deng
3c34a00d1b Preserve all submodules/parameters/buffers when unpickle graph module (#104115)
Summary:
When we pickle/unpickle a graph module in multipy, we would lose modules/attributes that are not referred to in the graph. This is because when unpickling an fx graph module, we use the stored `__dict__` and the fx graph to create a new graph module, and in GraphModule's init we drop any attribute that is not referred to in the graph.

This behavior is not ideal because we actually expect a graph module that's exactly the same after unpickling.
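
A rough illustration of that expectation (the attribute name is invented for the example):

```python
import pickle

import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = fx.symbolic_trace(M())
# Attach state that the traced graph itself never references.
gm.side_table = torch.nn.Linear(4, 4)

gm2 = pickle.loads(pickle.dumps(gm))
# Reconstructing from __dict__ + graph used to drop `side_table`; after this
# change the round-tripped module should still carry it.
print(hasattr(gm2, "side_table"))
```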

Test Plan:
```
buck test mode/opt caffe2/test:fx -- test_preserve_unused_attr_after_unpickle

Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D46976230

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104115
Approved by: https://github.com/houseroad
2023-06-26 06:59:48 +00:00
PyTorch MergeBot
29e3fddb08 Revert "Preserve original co_filename when FX symbolic_trace (#103885)"
This reverts commit b9f81a483a.

Reverted https://github.com/pytorch/pytorch/pull/103885 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/103885#issuecomment-1603612781))
2023-06-23 02:49:04 +00:00
Edward Z. Yang
b9f81a483a Preserve original co_filename when FX symbolic_trace (#103885)
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`

I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
2023-06-21 08:28:50 +00:00
Kazuaki Ishizaki
105ef68f72 Fix typos under torch/fx directory (#97596)
This PR fixes typos in comments and messages of `.py` files under `torch/fx` directory

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97596
Approved by: https://github.com/dagitses, https://github.com/kit1980
2023-04-10 21:57:36 +00:00
Jerry Zhang
2394e6baa9 [quant][fx] Change prepare_fx and convert_fx to preserve the GraphModule type of input (#94412)
Summary:
Previously, prepare_fx returned an ObservedGraphModule and convert_fx returned a QuantizedGraphModule; this was to preserve the attributes, since torch.fx.GraphModule did not preserve them.
After https://github.com/pytorch/pytorch/pull/92062 we preserve `model.meta`, so we can now store the attributes in model.meta to preserve them.

With this, we don't need to create a new type of GraphModule in these functions and can use GraphModule directly. This is useful for quantization in the PyTorch 2.0 flow: if other transformations also use GraphModule, the quantization passes will be composable with them.
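
A hedged sketch of that composability point, assuming the `torch.ao.quantization.quantize_fx` entry points and default qconfig mapping:

```python
import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)
prepared = prepare_fx(M().eval(), get_default_qconfig_mapping(), example_inputs)
prepared(*example_inputs)  # calibrate with representative data
quantized = convert_fx(prepared)

# Both stages now hand back plain GraphModules, with the quantization
# bookkeeping stored in .meta rather than a dedicated subclass.
print(type(prepared), type(quantized))
```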

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels
python test/test_quantization.py TestQuantizePT2E

Imported from OSS

Differential Revision: D42979722

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94412
Approved by: https://github.com/vkuzo
2023-02-09 23:03:23 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.
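
The two categories of change, sketched:

```python
# Before (Python 2 compatibility holdovers)
from __future__ import absolute_import, division, print_function, unicode_literals


class Foo(object):
    pass


# After (Python 3 only): no future imports, no explicit object base class
class Foo:
    pass
```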

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Han Qi
fc4e9931da [fx.GraphModule] Populate memo in deepcopy BEFORE copying children. (#93295)
Summary:
Apparently, if not, then at some point we might lose fields if the submodules have circular references.
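
A minimal sketch of the failure mode, using an artificial circular reference for illustration:

```python
import copy

import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = fx.symbolic_trace(M())
gm.self_ref = gm  # artificial circular reference, purely for illustration
gm_copy = copy.deepcopy(gm)

# With the memo populated before the children are copied, the copy should
# point back at itself rather than at a second, possibly incomplete duplicate.
print(gm_copy.self_ref is gm_copy)
```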


Pull Request resolved: https://github.com/pytorch/pytorch/pull/93295
Approved by: https://github.com/jerryzh168
2023-01-31 01:45:35 +00:00
Han Qi
8d7f9e2f79 Make __deepcopy__ of GraphModule able to handle circular reference. (#93038)
Summary:
One such place where a circular reference can occur: _load_state_dict_pre_hooks contains a _WrappedHook, and the _WrappedHook holds a weakref to the same module.


Pull Request resolved: https://github.com/pytorch/pytorch/pull/93038
Approved by: https://github.com/jerryzh168
2023-01-27 01:19:59 +00:00
Han Qi (qihqi)
f0e3c4929b only copy meta if available (#92623)
Test Plan:
```
buck2 test mode/opt //torchmultimodal/tests:tests -- --exact 'torchmultimodal/tests:tests - test_albef.py::test_albef_image_embeddings_momentum'
```
now passes

Reviewed By: malfet

Differential Revision: D42608385

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92623
Approved by: https://github.com/tugsbayasgalan
2023-01-19 23:39:53 +00:00
Han Qi
00fe63d1d8 fx Graph should copy meta on deepcopy (#92062)
Summary:
fx Graph should copy meta on deepcopy

Test Plan:
Unit test


Pull Request resolved: https://github.com/pytorch/pytorch/pull/92062
Approved by: https://github.com/zhxchen17
2023-01-18 02:49:14 +00:00
Zhengxu Chen
b7aa22d6db [fx] Fix GraphModule.print_readable() (#88730)
Summary: `__nested_code()` seems to have been removed.

Test Plan: CI

Differential Revision: D41149662

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88730
Approved by: https://github.com/SherlockNoMad
2022-11-09 21:39:48 +00:00
Horace He
e150a6212b Added gm.print_readable to torchinductor_trace output (#87717)
cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87717
Approved by: https://github.com/ngimel
2022-10-25 22:31:49 +00:00
anjali411
a6c0442cce Add __all__ to torch.{autograd, fx, cuda} submodules (#85343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85343
Approved by: https://github.com/albanD
2022-10-09 14:46:54 +00:00
Angela Yi
dd82b31e55 [fx] Add metadata to fx.GraphModule (#84378)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84378
Approved by: https://github.com/SherlockNoMad
2022-09-01 18:36:52 +00:00
Sherlock Huang
7e5c76da47 Make graph_module.print_readable() discoverable (#83960)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83960
Approved by: https://github.com/ezyang
2022-08-25 23:56:50 +00:00
Sherlock Huang
bf8d5e8328 Pretty print stack trace with gm.print_readable() (#83706)
Precondition: https://github.com/pytorch/torchdynamo/pull/899

Given the following function:
```
def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    s = s.sum()
```

Here are the possible results with the various tracing frontends: dynamo, symbolic_trace, make_fx.
- joint graph with torchdynamo.optimize("aot_nop")
Notice that it has a special stack trace for the gradient-addition node (for multiple uses of a tensor) in the backward pass.
Notice that "No stacktrace found for following nodes" is shown for nodes without a stack trace.
```
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None

    # No stacktrace found for following nodes
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    expand_default = torch.ops.aten.expand.default(tangents_1, [2, 10, 10]);  tangents_1 = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    unbind_int = torch.ops.aten.unbind.int(expand_default);  expand_default = None
    getitem = unbind_int[0]
    getitem_1 = unbind_int[1];  unbind_int = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    cos_default = torch.ops.aten.cos.default(pow_tensor_scalar);  pow_tensor_scalar = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem_1, cos_default);  getitem_1 = cos_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(getitem, detach_default_1, 0);  getitem = detach_default_1 = None

    # Gradient addition node due to mulitple use of tensor around:, File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    add_tensor_1 = torch.ops.aten.add.Tensor(mul_tensor, threshold_backward_default);  mul_tensor = threshold_backward_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 1.0);  add_tensor = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor_1 = torch.ops.aten.mul.Tensor(add_tensor_1, mul_scalar);  add_tensor_1 = mul_scalar = None
    sum_sym_int = torch.ops.aten.sum.SymInt(mul_tensor_1, [0], True)
    view_sym_int = torch.ops.aten.view.SymInt(sum_sym_int, [10]);  sum_sym_int = None
    return pytree.tree_unflatten([sum_default, mul_tensor_1, view_sym_int], self._out_spec)
```
- default symbolic_trace
Notice that nodes without a stack trace are folded under the same region.
```
def forward(self, a, b):

    # No stacktrace found for following nodes
    add = a + b;  a = b = None
    square = torch.square(add);  add = None
    relu = square.relu()
    sin = square.sin();  square = None
    stack = torch.stack([relu, sin]);  relu = sin = None
    sum_1 = stack.sum();  stack = None
    return sum_1
```
- symbolic_trace with record_stack_traces=True
```
def forward(self, a, b):

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add = a + b;  a = b = None
    square = torch.square(add);  add = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu = square.relu()

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin = square.sin();  square = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack = torch.stack([relu, sin]);  relu = sin = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_1 = stack.sum();  stack = None
    return sum_1
```

- make_fx without decomposition
```
def forward(self, a_1, b_1):

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(a_1, b_1);  a_1 = b_1 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2);  add_tensor = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar);  pow_tensor_scalar = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None
    return sum_default
```
- make_fx with decomposition to prims
```
def forward(self, a_1, b_1):

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(b_1, [10, 10], [1]);  b_1 = None
    add_default = torch.ops.prims.add.default(a_1, broadcast_in_dim_default);  a_1 = broadcast_in_dim_default = None
    mul_default = torch.ops.prims.mul.default(add_default, add_default);  add_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    le_default = torch.ops.prims.le.default(mul_default, 0.0)
    where_default = torch.ops.prims.where.default(le_default, 0.0, mul_default);  le_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.prims.sin.default(mul_default);  mul_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    cat_default = torch.ops.prims.cat.default([where_default, sin_default], 0);  where_default = sin_default = None
    split_dim_default = torch.ops.prims.split_dim.default(cat_default, 0, 2);  cat_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    convert_element_type_default = torch.ops.prims.convert_element_type.default(split_dim_default, torch.float32);  split_dim_default = None
    sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0, 1, 2]);  convert_element_type_default = None
    return sum_default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83706
Approved by: https://github.com/Chillee, https://github.com/ezyang
2022-08-24 23:00:57 +00:00
Sergii Dymchenko
591222f5d9 Fix use-dict-literal lint (#83718)
Fix use-dict-literal pylint suggestions by changing `dict()` to `{}`. This PR should do the change for every Python file except test/jit/test_list_dict.py, where I think the intent is to test the constructor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83718
Approved by: https://github.com/albanD
2022-08-24 00:26:46 +00:00
Sherlock Huang
43e7fee764 [Reland] Recursively print graph module and its submodule (#81639)
ghstack-source-id: fcfc024c440981ee3fe3537a5816089eadf2cc13
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080


Pull Request resolved: https://github.com/pytorch/pytorch/pull/81639
Approved by: https://github.com/ezyang
2022-07-21 16:58:25 +00:00
PyTorch MergeBot
4035a53cca Revert "Recursively print graph module and its submodule (#81080)"
This reverts commit fe7262329c.

Reverted https://github.com/pytorch/pytorch/pull/81080 on behalf of https://github.com/DanilBaibak due to Break internal build
2022-07-18 14:46:26 +00:00
Sherlock Huang
fe7262329c Recursively print graph module and its submodule (#81080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080
Approved by: https://github.com/ezyang
2022-07-18 01:19:03 +00:00
Horace He
e7e835e50a Fix to folder by adding custom_builtins to dump (#81433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81433
Approved by: https://github.com/jamesr66a
2022-07-14 21:39:13 +00:00
PyTorch MergeBot
58532256e9 Revert "Add __all__ for torch.distributed and fx modules (#80460)"
This reverts commit 5d40c3d5c8.

Reverted https://github.com/pytorch/pytorch/pull/80460 on behalf of https://github.com/malfet due to Broke MacOS testing, see https://github.com/pytorch/pytorch/runs/7105579664?check_suite_focus=true
2022-06-29 16:20:55 +00:00