Commit Graph

107 Commits

Author SHA1 Message Date
Aaron Gokaslan
cb856b08b2 [BE]: Attach cause to some exceptions and enable RUFF TRY200 (#111496)
Did some easy fixes from enabling TRY200. Most of these look like oversights rather than intentional choices. The proper way to silence the rule intentionally is `raise ... from None`, which signals that you considered whether the exception should carry its cause and decided against it.
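
For illustration, a minimal sketch of the chaining pattern the rule enforces (not code from this PR):

```
def load_config(path):
    try:
        with open(path) as f:
            return f.read()
    except OSError as err:
        # TRY200-compliant: attach the original exception as the cause.
        raise RuntimeError(f"could not load config from {path}") from err

def parse_port(text):
    try:
        return int(text)
    except ValueError:
        # Deliberate suppression: the cause adds nothing here, and `from None` records that decision.
        raise ValueError(f"invalid port: {text!r}") from None
```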

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111496
Approved by: https://github.com/malfet
2023-10-19 21:56:36 +00:00
Wenting Wang
675df7520a [tgif][multiforward] allow codegen to generate different func name (#111446)
Summary: see Shiyan's design doc for the ATM TS publish-weights dedupe: https://fb.quip.com/HnUVAjUMaXMQ

Test Plan: tested in N4454041 (after D50341352) that the multiforward method works for the TS model

Differential Revision: D45750812

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111446
Approved by: https://github.com/842974287
2023-10-19 21:19:30 +00:00
Aaron Gokaslan
a0632389b7 [BE]: Update lintrunner mypy to 1.6.0 (#111375)
Follow-up to #111305 that updates lintrunner's mypy version as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111375
Approved by: https://github.com/malfet
2023-10-17 01:22:06 +00:00
Sam Larsen
0dfa354570 [inductor] Implement Fx graph caching to improve warm compilation time. (#103453)
Summary: Implement an on-disk cache to save and reuse compiled FX graphs. This implementation does not yet handle tensors with symbolic shapes; that will be addressed in a follow-up PR.
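
A minimal sketch of the idea, with hypothetical helper names and cache location (not inductor's actual implementation), assuming the compiled artifact is picklable: key an on-disk cache by a hash of the graph plus input metadata, and reuse the stored result on a hit.

```
import hashlib
import os
import pickle

CACHE_DIR = "/tmp/fx_graph_cache"  # hypothetical location

def cache_key(gm, example_inputs):
    # Hash the generated Python for the graph together with input shapes/dtypes/strides.
    h = hashlib.sha256()
    h.update(gm.code.encode())
    for t in example_inputs:
        h.update(str((tuple(t.shape), t.dtype, t.stride())).encode())
    return h.hexdigest()

def compile_fx_cached(gm, example_inputs, compile_fn):
    os.makedirs(CACHE_DIR, exist_ok=True)
    path = os.path.join(CACHE_DIR, cache_key(gm, example_inputs))
    if os.path.exists(path):                   # cache hit: skip recompilation
        with open(path, "rb") as f:
            return pickle.load(f)
    compiled = compile_fn(gm, example_inputs)  # cache miss: compile and store
    with open(path, "wb") as f:
        pickle.dump(compiled, f)
    return compiled
```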

Test Plan:
* New unit tests exercising saving to and loading from the cache.
* New unit tests exercising the cache-key calculations.
* Ran several benchmarks to verify cache hits and the resulting compilation times.

Differential Revision: [D50255289](https://our.internmc.facebook.com/intern/diff/D50255289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103453
Approved by: https://github.com/eellison, https://github.com/Chillee
2023-10-13 13:33:56 +00:00
PyTorch MergeBot
7fbfa4e020 Revert "[inductor] Implement Fx graph caching to improve warm compilation time. (#103453)"
This reverts commit fc1105b282.

Reverted https://github.com/pytorch/pytorch/pull/103453 on behalf of https://github.com/kit1980 due to Same issue unfortunately, the newly added test fails on internal builds ([comment](https://github.com/pytorch/pytorch/pull/103453#issuecomment-1760202365))
2023-10-12 18:54:51 +00:00
Sam Larsen
fc1105b282 [inductor] Implement Fx graph caching to improve warm compilation time. (#103453)
Summary: Implement an on-disk cache to save and reuse compiled FX graphs. This implementation does not yet handle tensors with symbolic shapes; that will be addressed in a follow-up PR.

Test Plan:
* New unit tests exercising saving to and loading from the cache.
* New unit tests exercising the cache-key calculations.
* Ran several benchmarks to verify cache hits and the resulting compilation times.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103453
Approved by: https://github.com/eellison, https://github.com/Chillee
2023-10-11 14:39:14 +00:00
PyTorch MergeBot
3100d3e661 Revert "[inductor] Implement Fx graph caching to improve warm compilation time. (#103453)"
This reverts commit 8a8668e1ae.

Reverted https://github.com/pytorch/pytorch/pull/103453 on behalf of https://github.com/kit1980 due to The newly added test fails on internal builds ([comment](https://github.com/pytorch/pytorch/pull/103453#issuecomment-1756449919))
2023-10-10 23:21:59 +00:00
Sam Larsen
8a8668e1ae [inductor] Implement Fx graph caching to improve warm compilation time. (#103453)
Summary: Implement an on-disk cache to save and reuse compiled FX graphs. This implementation does not yet handle tensors with symbolic shapes; that will be addressed in a follow-up PR.

Test Plan:
* New unit tests exercising saving to and loading from the cache.
* New unit tests exercising the cache-key calculations.
* Ran several benchmarks to verify cache hits and the resulting compilation times.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103453
Approved by: https://github.com/eellison
2023-10-08 20:32:15 +00:00
William Wen
b904432e82 [dynamo] preserve some FX node metadata of GraphModules (#107067)
Requested by @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case where we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and then want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.

This feature also has the additional benefit of preserving the originating line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.

The added unit test demonstrates the added functionality of this PR.

~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds to by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and look up the node index using the line number map (a sketch of this lookup follows below)
- index into the original node metadata ~using the counter variable's tracked value.~
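
A hedged sketch of that lookup, with hypothetical names (per the bullets above, the real map is built during `GraphModule` codegen and handed to dynamo's bytecode analyzer):

```
# Built while GraphModule source is generated: generated line number -> node index.
lineno_to_node_idx: dict[int, int] = {}

def record_codegen_line(lineno: int, node_idx: int) -> None:
    lineno_to_node_idx[lineno] = node_idx

def metadata_for_new_proxy(current_lineno: int, node_metadata: list[dict]) -> dict:
    # Called when dynamo creates a new proxy: map the instruction's line number
    # back to the original node and return its preserved metadata.
    idx = lineno_to_node_idx.get(current_lineno)
    return node_metadata[idx] if idx is not None else {}
```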

~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 has shortcomings. For node names, we only have access to the new node names, not the old ones. Using line numbers is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~

Differential Revision: [D49257108](https://our.internmc.facebook.com/intern/diff/D49257108)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
2023-09-15 23:29:14 +00:00
PyTorch MergeBot
c5e7588613 Revert "[dynamo] preserve some FX node metadata of GraphModules (#107067)"
This reverts commit 1d42148fee.

Reverted https://github.com/pytorch/pytorch/pull/107067 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/107067#issuecomment-1717321061))
2023-09-13 09:59:33 +00:00
William Wen
1d42148fee [dynamo] preserve some FX node metadata of GraphModules (#107067)
Requested by @tugsbayasgalan: we want dynamo to preserve some FX node metadata when we trace `GraphModule`s (`nn_module_stack`, `source_fn`, `stack_trace`). This is helpful for the case where we export an aten-level `GraphModule`, add some (possibly non-torch or non-aten) ops, and then want to transform the graph back into an aten-level graph. Without preserving metadata, future passes that look at metadata (e.g. quantization passes) won't work.

This feature also has the additional benefit of preserving the originating line of code when `print_readable`'ing a `GraphModule`. This is helpful when debugging graphs that have passed through dynamo several times.

The added unit test demonstrates the added functionality of this PR.

~This PR is currently a proof-of-concept implementation that shows that preserving node metadata across dynamo is possible.~ This PR preserves node metadata across dynamo by doing the following:
- ~inject a counter variable into the `GraphModule` source code, which is incremented every time a node is run~
- Construct a line number -> node index map in `GraphModule` as the source code is being generated.
- pass a list of node metadata and the line number map to dynamo's bytecode analyzer
- ~dynamo traces the counter as a `ConstantVariable`, so when we create a new proxy, we can determine which original node index this proxy corresponds to by looking at the value of the traced counter~
- When we create a new proxy, get the current instruction's line number, and get the node index using the line number map
- index into the original node metadata ~using the counter variable's tracked value.~

~Some things that should be addressed off the top of my head:~
- ~Is this feature even desirable? (Do we really want Dynamo to have special behavior for `GraphModules`? Should we expect users to re-export `GraphModules`?)~
- ~Is there a better approach than to use a counter? We considered using node names, line numbers, and assuming that proxies are created in the same order as the nodes, but each of these 3 has shortcomings. For node names, we only have access to the new node names, not the old ones. Using line numbers is fragile. The third is problematic since not all created nodes go through `create_proxy` (e.g. inputs). We currently generate a line number to node index map when the `GraphModule`'s code is generated.~
- ~What's the best way to send data across the "CPython gap"? That is, it is not obvious how to cleanly pass data from dynamo's `eval_frame.py:_TorchDynamoContext.__call__` to `symbolic_convert.py:InstructionTranslatorBase.__init__`. In this PR, we use a global.~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107067
Approved by: https://github.com/jansel
2023-09-11 17:11:51 +00:00
Jing Shan
fc2b980000 [Lint] Auto format graph_module.py (#108594)
Summary: Auto format the `graph_module.py` file

Test Plan: lint

Differential Revision: D48983066

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108594
Approved by: https://github.com/jiayisuse
2023-09-08 00:04:21 +00:00
Edward Z. Yang
666aeaa313 Preserve original co_filename when FX symbolic_trace (#103885)
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`

I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.
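
A hedged sketch of that plumbing with a hypothetical helper name (the exact `co_fields` keys are an assumption, based on the example output above):

```
import linecache

def exec_with_source(src: str, globals_: dict, co_fields=None):
    key = "<eval_with_key>.0"
    if co_fields:
        # e.g. {"co_filename": ".../test/dynamo/test_misc.py", "co_firstlineno": 5683, "co_name": "forward"}
        key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}"
    exec(compile(src, key, "exec"), globals_)
    # Register the source under that key so tracebacks and pdb can display it.
    linecache.cache[key] = (len(src), None, src.splitlines(keepends=True), key)
```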

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
2023-07-05 22:00:05 +00:00
Shiyan Deng
3c34a00d1b Preserve all submodules/parameters/buffers when unpickle graph module (#104115)
Summary:
When we pickle/unpickle a graph module in multipy, we would lose modules/attributes that are not referenced in the graph. This is because when unpickling an FX graph module, we use the stored `__dict__` and the FX graph to create a new graph module; in GraphModule init, we drop any attribute that is not referenced in the graph.

This behavior is not ideal, because we actually expect a graph module that's exactly the same after unpickling.
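
A small illustration of the expected behavior, assuming (per this change) that unreferenced attributes now survive the round-trip:

```
import pickle
import torch
import torch.fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

gm = torch.fx.symbolic_trace(M())
gm.side_table = torch.nn.Linear(2, 2)   # attribute never referenced by the graph
gm2 = pickle.loads(pickle.dumps(gm))
assert hasattr(gm2, "side_table")       # expected to hold after this change
```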

Test Plan:
```
buck test mode/opt caffe2/test:fx -- test_preserve_unused_attr_after_unpickle

Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Differential Revision: D46976230

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104115
Approved by: https://github.com/houseroad
2023-06-26 06:59:48 +00:00
PyTorch MergeBot
29e3fddb08 Revert "Preserve original co_filename when FX symbolic_trace (#103885)"
This reverts commit b9f81a483a.

Reverted https://github.com/pytorch/pytorch/pull/103885 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/103885#issuecomment-1603612781))
2023-06-23 02:49:04 +00:00
Edward Z. Yang
b9f81a483a Preserve original co_filename when FX symbolic_trace (#103885)
Previously, you'd get `<eval_with_key>.0`; now you get `<eval_with_key>.0 from /data/users/ezyang/b/pytorch/test/dynamo/test_misc.py:5683 in forward`

I used to do this with globals, but now I do it with a `co_fields` parameter that's plumbed around, because putting things in globals has implications(TM). Happy to bikeshed on the `co_fields` structure.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103885
Approved by: https://github.com/albanD
2023-06-21 08:28:50 +00:00
Kazuaki Ishizaki
105ef68f72 Fix typos under torch/fx directory (#97596)
This PR fixes typos in comments and messages of `.py` files under the `torch/fx` directory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97596
Approved by: https://github.com/dagitses, https://github.com/kit1980
2023-04-10 21:57:36 +00:00
Jerry Zhang
2394e6baa9 [quant][fx] Change prepare_fx and convert_fx to preserve the GraphModule type of input (#94412)
Summary:
Previously, prepare_fx returned an ObservedGraphModule and convert_fx returned a QuantizedGraphModule; this was done to preserve extra attributes, since torch.fx.GraphModule did not preserve them. After https://github.com/pytorch/pytorch/pull/92062 we preserve `model.meta`, so we can now store those attributes in model.meta instead.

With this, we don't need to create a new type of GraphModule in these functions and can use GraphModule directly. This is useful for quantization in the PyTorch 2.0 flow: if other transformations also use GraphModule, the quantization passes will be composable with them.
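
A small sketch of what this enables (the meta key is hypothetical, and it assumes `model.meta` survives deepcopy as the linked PR describes):

```
import copy
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

gm = torch.fx.symbolic_trace(M())
# Instead of returning an ObservedGraphModule subclass, stash the extra state in meta.
gm.meta["observed_state"] = {"node_name_to_qconfig": {}}
gm2 = copy.deepcopy(gm)
assert "observed_state" in gm2.meta   # plain GraphModule, attributes preserved via meta
```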

Test Plan:
python test/test_quantization.py TestQuantizeFx
python test/test_quantization.py TestQuantizeFxOps
python test/test_quantization.py TestQuantizeFxModels
python test/test_quantization.py TestQuantizePT2E

Imported from OSS

Differential Revision: D42979722

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94412
Approved by: https://github.com/vkuzo
2023-02-09 23:03:23 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR does only two things: it removes unnecessary inheritance from `object` and removes unused `__future__` imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Han Qi
fc4e9931da [fx.GraphModule] Populate memo in deepcopy BEFORE copying children. (#93295)
Summary:
Apparently, if we don't, we might at some point lose fields when the submodules contain circular references.
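
A minimal illustration of the memo-ordering point in plain Python (not the GraphModule `__deepcopy__` itself):

```
import copy

class Node:
    def __init__(self, name):
        self.name = name
        self.parent = None       # back-reference can create a cycle
        self.children = {}

    def __deepcopy__(self, memo):
        res = type(self).__new__(type(self))
        # Register the copy in memo *before* copying fields, so a circular
        # reference back to `self` resolves to `res` instead of starting a
        # second, partial copy.
        memo[id(self)] = res
        res.__dict__ = copy.deepcopy(self.__dict__, memo)
        return res

root = Node("root")
child = Node("child")
child.parent = root
root.children["child"] = child
root2 = copy.deepcopy(root)
assert root2.children["child"].parent is root2
```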

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93295
Approved by: https://github.com/jerryzh168
2023-01-31 01:45:35 +00:00
Han Qi
8d7f9e2f79 Make __deepcopy__ of GraphModule able to handle circular reference. (#93038)
Summary:
One place where a circular reference can occur: `_load_state_dict_pre_hooks` contains a `_WrappedHook`, and the `_WrappedHook` holds a weakref back to the same module.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93038
Approved by: https://github.com/jerryzh168
2023-01-27 01:19:59 +00:00
Han Qi (qihqi)
f0e3c4929b only copy meta if available (#92623)
Test Plan:
```
buck2 test mode/opt //torchmultimodal/tests:tests -- --exact 'torchmultimodal/tests:tests - test_albef.py::test_albef_image_embeddings_momentum'
```
now passes

Reviewed By: malfet

Differential Revision: D42608385

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92623
Approved by: https://github.com/tugsbayasgalan
2023-01-19 23:39:53 +00:00
Han Qi
00fe63d1d8 fx Graph should copy meta on deepcopy (#92062)
Summary:
fx Graph should copy meta on deepcopy

Test Plan:
Unit test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92062
Approved by: https://github.com/zhxchen17
2023-01-18 02:49:14 +00:00
Zhengxu Chen
b7aa22d6db [fx] Fix GraphModule.print_readable() (#88730)
Summary: `__nested_code()` seems to have been removed.

Test Plan: CI

Differential Revision: D41149662

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88730
Approved by: https://github.com/SherlockNoMad
2022-11-09 21:39:48 +00:00
Horace He
e150a6212b Added gm.print_readable to torchinductor_trace output (#87717)
cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87717
Approved by: https://github.com/ngimel
2022-10-25 22:31:49 +00:00
anjali411
a6c0442cce Add __all__ to torch.{autograd, fx, cuda} submodules (#85343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85343
Approved by: https://github.com/albanD
2022-10-09 14:46:54 +00:00
Angela Yi
dd82b31e55 [fx] Add metadata to fx.GraphModule (#84378)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84378
Approved by: https://github.com/SherlockNoMad
2022-09-01 18:36:52 +00:00
Sherlock Huang
7e5c76da47 Make graph_module.print_readable() discoverable (#83960)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83960
Approved by: https://github.com/ezyang
2022-08-25 23:56:50 +00:00
Sherlock Huang
bf8d5e8328 Pretty print stack trace with gm.print_readable() (#83706)
Precondition: https://github.com/pytorch/torchdynamo/pull/899

Given the following function
```
def my_relu(a):
    return a.relu()

def func(a, b):
    d = torch.square(a + b)
    e = my_relu(d)
    f = d.sin()
    s = torch.stack([e, f])
    s = s.sum()
    return s
```

Here are the possible results with various tracing frontends: dynamo, symbolic_trace, make_fx
- joint graph with torchdynamo.optimize("aot_nop")
Notice that it has a special stack trace for the gradient-addition node (for multiple uses of a tensor) in the backward pass
Notice that "No stacktrace found for following nodes" is shown for nodes without a stack trace
```
def forward(self, primals, tangents):
    primals_1, primals_2, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None

    # No stacktrace found for following nodes
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    expand_default = torch.ops.aten.expand.default(tangents_1, [2, 10, 10]);  tangents_1 = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    unbind_int = torch.ops.aten.unbind.int(expand_default);  expand_default = None
    getitem = unbind_int[0]
    getitem_1 = unbind_int[1];  unbind_int = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    cos_default = torch.ops.aten.cos.default(pow_tensor_scalar);  pow_tensor_scalar = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem_1, cos_default);  getitem_1 = cos_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(getitem, detach_default_1, 0);  getitem = detach_default_1 = None

    # Gradient addition node due to mulitple use of tensor around:, File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    add_tensor_1 = torch.ops.aten.add.Tensor(mul_tensor, threshold_backward_default);  mul_tensor = threshold_backward_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 1.0);  add_tensor = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor_1 = torch.ops.aten.mul.Tensor(add_tensor_1, mul_scalar);  add_tensor_1 = mul_scalar = None
    sum_sym_int = torch.ops.aten.sum.SymInt(mul_tensor_1, [0], True)
    view_sym_int = torch.ops.aten.view.SymInt(sum_sym_int, [10]);  sum_sym_int = None
    return pytree.tree_unflatten([sum_default, mul_tensor_1, view_sym_int], self._out_spec)
```
- default symbolic_trace
Notice that nodes without a stack trace are folded under the same region
```
def forward(self, a, b):

    # No stacktrace found for following nodes
    add = a + b;  a = b = None
    square = torch.square(add);  add = None
    relu = square.relu()
    sin = square.sin();  square = None
    stack = torch.stack([relu, sin]);  relu = sin = None
    sum_1 = stack.sum();  stack = None
    return sum_1
```
- symbolic_trace with record_stack_traces=True
```
def forward(self, a, b):

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add = a + b;  a = b = None
    square = torch.square(add);  add = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu = square.relu()

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin = square.sin();  square = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack = torch.stack([relu, sin]);  relu = sin = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_1 = stack.sum();  stack = None
    return sum_1
```

- make_fx without decomposition
```
def forward(self, a_1, b_1):

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    add_tensor = torch.ops.aten.add.Tensor(a_1, b_1);  a_1 = b_1 = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(add_tensor, 2);  add_tensor = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar)
    detach_default = torch.ops.aten.detach.default(relu_default)

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.aten.sin.default(pow_tensor_scalar);  pow_tensor_scalar = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    stack_default = torch.ops.aten.stack.default([relu_default, sin_default]);  relu_default = sin_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    sum_default = torch.ops.aten.sum.default(stack_default);  stack_default = None
    return sum_default
```
- make_fx with decomposition to prims
```
def forward(self, a_1, b_1):

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 41, in func, d = torch.square(a + b)
    broadcast_in_dim_default = torch.ops.prims.broadcast_in_dim.default(b_1, [10, 10], [1]);  b_1 = None
    add_default = torch.ops.prims.add.default(a_1, broadcast_in_dim_default);  a_1 = broadcast_in_dim_default = None
    mul_default = torch.ops.prims.mul.default(add_default, add_default);  add_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 38, in my_relu, return a.relu()
    le_default = torch.ops.prims.le.default(mul_default, 0.0)
    where_default = torch.ops.prims.where.default(le_default, 0.0, mul_default);  le_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 43, in func, f = d.sin()
    sin_default = torch.ops.prims.sin.default(mul_default);  mul_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 44, in func, s = torch.stack([e, f])
    cat_default = torch.ops.prims.cat.default([where_default, sin_default], 0);  where_default = sin_default = None
    split_dim_default = torch.ops.prims.split_dim.default(cat_default, 0, 2);  cat_default = None

    # File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 45, in func, s = s.sum()
    convert_element_type_default = torch.ops.prims.convert_element_type.default(split_dim_default, torch.float32);  split_dim_default = None
    sum_default = torch.ops.prims.sum.default(convert_element_type_default, [0, 1, 2]);  convert_element_type_default = None
    return sum_default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83706
Approved by: https://github.com/Chillee, https://github.com/ezyang
2022-08-24 23:00:57 +00:00
Sergii Dymchenko
591222f5d9 Fix use-dict-literal lint (#83718)
Fix use-dict-literal pylint suggestions by changing `dict()` calls to `{}` literals. This PR makes the change in every Python file except test/jit/test_list_dict.py, where I think the intent is to test the constructor.
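
The change in miniature:

```
opts = dict(lr=0.1, momentum=0.9)     # flagged by use-dict-literal
opts = {"lr": 0.1, "momentum": 0.9}   # preferred: no constructor name lookup, slightly faster
```
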
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83718
Approved by: https://github.com/albanD
2022-08-24 00:26:46 +00:00
Sherlock Huang
43e7fee764 [Reland] Recursively print graph module and its submodule (#81639)
ghstack-source-id: fcfc024c440981ee3fe3537a5816089eadf2cc13
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81639
Approved by: https://github.com/ezyang
2022-07-21 16:58:25 +00:00
PyTorch MergeBot
4035a53cca Revert "Recursively print graph module and its submodule (#81080)"
This reverts commit fe7262329c.

Reverted https://github.com/pytorch/pytorch/pull/81080 on behalf of https://github.com/DanilBaibak due to Break internal build
2022-07-18 14:46:26 +00:00
Sherlock Huang
fe7262329c Recursively print graph module and its submodule (#81080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81080
Approved by: https://github.com/ezyang
2022-07-18 01:19:03 +00:00
Horace He
e7e835e50a Fix to folder by adding custom_builtins to dump (#81433)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81433
Approved by: https://github.com/jamesr66a
2022-07-14 21:39:13 +00:00
PyTorch MergeBot
58532256e9 Revert "Add __all__ for torch.distributed and fx modules (#80460)"
This reverts commit 5d40c3d5c8.

Reverted https://github.com/pytorch/pytorch/pull/80460 on behalf of https://github.com/malfet due to Broke MacOS testing, see https://github.com/pytorch/pytorch/runs/7105579664?check_suite_focus=true
2022-06-29 16:20:55 +00:00
anjali411
5d40c3d5c8 Add __all__ for torch.distributed and fx modules (#80460)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80460
Approved by: https://github.com/albanD, https://github.com/rohan-varma
2022-06-29 02:53:56 +00:00
James Reed
7311390d35 [WIP] Make constructor calls in experimental MetaTracer serializable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76789

Approved by: https://github.com/pbelevich
2022-05-11 00:19:47 +00:00
Michael Suo
fb0f285638 [lint] upgrade mypy to latest version
Fixes https://github.com/pytorch/pytorch/issues/75927.

Had to fix some bugs and add some ignores.

To check if clean:
```
lintrunner --paths-cmd='git grep -Il .' --take MYPY,MYPYSTRICT
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76753

Approved by: https://github.com/malfet
2022-05-03 20:51:34 +00:00
PyTorch MergeBot
3d7428d9ac Revert "[lint] upgrade mypy to latest version"
This reverts commit 9bf18aab94.

Reverted https://github.com/pytorch/pytorch/pull/76753 on behalf of https://github.com/suo
2022-05-03 20:01:18 +00:00
Michael Suo
9bf18aab94 [lint] upgrade mypy to latest version
Fixes https://github.com/pytorch/pytorch/issues/75927.

Had to fix some bugs and add some ignores.

To check if clean:
```
lintrunner --paths-cmd='git grep -Il .' --take MYPY,MYPYSTRICT
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76753

Approved by: https://github.com/malfet
2022-05-03 19:43:28 +00:00
James Reed
bf730e5039 Fix unnecessary recursion in GraphModule.__call__
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76068

Approved by: https://github.com/Chillee
2022-04-21 03:25:39 +00:00
Pavel Belevich
96c8f64459 Remove with_traceback(None) in wrapped_call to show the root cause error
Before:
```
Traceback (most recent call last):
  File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
    t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
    return self.executor.run(*executor_args)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
    return super().run(*args, initial_env=initial_env)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
    self.env[node] = self.run_node(node)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
    return super().call_module(target, args, kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
    return submod(*args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e.with_traceback(None)
AttributeError: 'NoneType' object has no attribute 'dtype'
```
After:
```
Traceback (most recent call last):
  File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
    t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
    return self.executor.run(*executor_args)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
    return super().run(*args, initial_env=initial_env)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
    self.env[node] = self.run_node(node)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
    return super().call_module(target, args, kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
    return submod(*args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
    raise e
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 622, in wrapped_call
    return super(cls, self).__call__(*args, **kwargs)
  File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "<eval_with_key>.42", line 74, in forward
  File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/utils/fx.py", line 180, in wrapper
    return func(*args, **kwargs)
  File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/modeling_utils.py", line 256, in create_extended_attention_mask_for_decoder
    causal_mask = causal_mask.to(attention_mask.dtype)
AttributeError: 'NoneType' object has no attribute 'dtype'
```

The last lines of the stack trace show where the problem is.
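
A minimal illustration of the difference in plain Python (not the fx `wrapped_call` itself):

```
import traceback

def inner():
    return None.dtype                  # AttributeError: 'NoneType' object has no attribute 'dtype'

def wrapped_call_old():
    try:
        return inner()
    except Exception as e:
        raise e.with_traceback(None)   # traceback restarts here; the frame for `inner` is lost

def wrapped_call_new():
    try:
        return inner()
    except Exception as e:
        raise e                        # original frames, including `inner`, are preserved

try:
    wrapped_call_new()
except AttributeError:
    traceback.print_exc()              # the last frames point at `inner` and the `.dtype` access
```
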
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74655
Approved by: https://github.com/ansley, https://github.com/rohan-varma
2022-03-25 14:40:45 +00:00
Animesh Jain
7ebab9247d FX graph module - prevent infinite recursion (#73866)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73866

super(type(self), self) in wrapped_call leads to infinite recursion for subclasses of an FX graph module. This happens when we call _stateless.functional_call on an FX module: https://github.com/pytorch/pytorch/blob/master/torch%2Fnn%2Futils%2F_stateless.py
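
A minimal reproduction of the failure mode in plain Python (not the fx code itself); the usual fix is to name the defining class explicitly instead of `type(self)`:

```
class GraphModuleLike:
    def __call__(self):
        return "ok"

class WrappedCall(GraphModuleLike):
    def __call__(self):
        # BAD: binds to the runtime type, which is the subclass for subclass instances.
        return super(type(self), self).__call__()

class UserSubclass(WrappedCall):
    pass

print(WrappedCall()())   # "ok": lookup after WrappedCall finds GraphModuleLike.__call__
try:
    UserSubclass()()     # lookup after UserSubclass finds WrappedCall.__call__ again -> recursion
except RecursionError:
    print("infinite recursion, as described above")
```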

Test Plan:
Tests added in https://github.com/pytorch/pytorch/pull/62436

Imported from OSS

Reviewed By: jansel

Differential Revision: D34737828

fbshipit-source-id: 871b897e1210173ccc83fe34d53fc41af00a39ee
(cherry picked from commit 3d0c5fc71503fa2782b497a9d39ce26288fd219b)
2022-03-09 06:09:57 +00:00
Horace He
d635d0f86e Refactor FX codegen into extensible Codegen object (#72566)
Summary:
The goal of this is to make FX's codegen extensible. I've refactored it into a class with 5 extensibility points on it.

```
class CodeGen(object):
    def generate_prologue(self, free_vars: List[str], maybe_return_annotation: str) -> str:
        """
        Given the free variables and a return annotation, generates the beginning of the FX function.
        By default, `generate_prologue(['a', 'b'], '') == 'def forward(a, b):'`
        """
    def generate_output(self, output_args: Argument) -> str:
        """
        Given the output arguments, generates the return statement of the FX function.
        """
    def process_inputs(self, args: Any) -> Any:
        """
        Transforms the inputs so that the graph can take them as arguments, as
        non-default codegen may result in the inputs to the function being
        different from the inputs to the graph.

        If the graph was directly runnable, this invariant should hold true
        `f.process_outputs(f.graph(*f.process_inputs(*inputs))) == f(*inputs)`
        """
    def process_outputs(self, outputs: Any) -> Any:
        """
        Transforms the outputs of the graph to be identical to the codegen.

        See ``process_inputs`` for more details.
        """
    def additional_globals(self) -> List[Tuple[str, Any]]:
        """
        If your codegen uses extra global values, add them here.
        For example, return [('List', typing.List)] if you need ``List`` in the global context.
        """
```

So, for example, the `ListCodeGen` we want for AOTAutograd looks like this
```
        class ListCodeGen(CodeGen):
            def generate_prologue(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert(len(inputs) == 1)
                return inputs[0]
```
and
```
        def f(a, b):
            return a + b

        nf = fx.symbolic_trace(f)
        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        print(nf.code)
```
would result in
```
def forward(self, args_list: List[torch.Tensor]):
    a, b = args_list
    add = a + b;  a = b = None
    return add
```

Backwards compatibility changes - I added `process_outputs` and `process_inputs` to `fx.Graph`, while removing `flatten_inputs` and `flatten_outputs` - those didn't have `backwards_compatibility` on them, so I *think* it's probably fine?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72566

Reviewed By: desertfire

Differential Revision: D34160424

Pulled By: Chillee

fbshipit-source-id: ebf6411312b373e3fbcb13288a34befa449a2375
(cherry picked from commit 13cd12eaa1)
2022-02-11 18:13:29 +00:00
Horace He
df6eb9bbab Fixed to_folder not saving dtype (#69983)
Summary:
As above.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69983

Reviewed By: pbelevich, ngimel

Differential Revision: D33466529

Pulled By: Chillee

fbshipit-source-id: 2d2f0ad5b8e2492aba4c19fa034c8b6c0848a568
2022-01-06 22:15:56 -08:00
James Reed
3eb9443619 [FX] Fix issue where GraphModule.delete_all_unused_submodules deletes submodules from called leaf modules (#66430)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66430

On the whole, I'm not totally satisfied with this approach. I think we should be building a prefix tree data structure during initial iteration over the submodules and querying that when deleting submodules. But I think this approach works and I want to see if we can get it in before 1.10

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D31546137

Pulled By: jamesr66a

fbshipit-source-id: f08b8409a3cf511277017ccccb916097b7c4c4fe
2021-10-11 19:37:51 -07:00
James Reed
0c4e4e588e [FX] Rename reduce functions back to their old, public names (#64324)
Summary:
Unfortunately, pickle serializes the names of these functions. This PR also puts them under backward-compatibility enforcement.
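
A short illustration of why the names matter (hypothetical helper name, not fx's actual one): pickle stores module-level functions by module path plus qualified name, not by value, so renaming the function breaks existing pickles that reference it.

```
import pickle
import pickletools

def reduce_helper(state):
    return state

payload = pickle.dumps(reduce_helper)
pickletools.dis(payload)   # the opcodes reference '__main__' and 'reduce_helper' by name
```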

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64324

Test Plan: Local repro https://fb.workplace.com/groups/3440841732711443/permalink/4018921611570116/

Reviewed By: SplitInfinity, TailofJune

Differential Revision: D30684185

Pulled By: jamesr66a

fbshipit-source-id: 900701220155d15115cd0c07cf7774a2891bd04f
2021-08-31 22:36:11 -07:00
Jay Leverett
44fcb00a56 Fix redundant class definition in GraphModule singleton constructor (#64274)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/63883

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64274

Reviewed By: jamesr66a

Differential Revision: D30675970

Pulled By: jayleverett

fbshipit-source-id: e74ef2a28013f0fa7c58d14f38e66cfe48d26b74
2021-08-31 17:34:14 -07:00
James Reed
538647fe1f [WIP][FX] BC guarantees for 1.10 (#63888)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63888

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D30523133

Pulled By: jamesr66a

fbshipit-source-id: b04cc0d842a74862f42ecba98b757310cd2ec7b0
2021-08-30 19:56:46 -07:00
James Reed
4e37a015c7 [FX] Fix _replicate_for_data_parallel (#63821)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63821

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D30502115

Pulled By: jamesr66a

fbshipit-source-id: 0f004f95def6e1ba21ccbeab40cb0a739a0ad20c
2021-08-24 13:48:15 -07:00