Commit Graph

192 Commits

Riley Dulin
24482e5c68 [torch][fx] Set maximum warning count during fx.Graph.lint (#135069)
Summary:
resnet152 spent about 15 minutes writing warning messages in _unlift
during `to_executorch` because they're all written to unbuffered stderr
by the `warnings` module.

These warnings are almost always about get_attr nodes referencing a
non-existent name:
```python
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
  'not reference an nn.Module, nn.Parameter, or buffer, which is '
  'what \'get_attr\' Nodes typically target'
)
```
I'm not aware of a way to configure the warnings module to write this out
at most once, so I'm just going to disable the lint for now.
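
For reference, a minimal sketch of what a warning cap could look like (hypothetical names and cap value; this is not the actual change):

```python
import warnings

_MAX_LINT_WARNINGS = 10  # hypothetical cap, not the PR's actual constant
_lint_warning_count = 0

def _lint_warn(msg: str) -> None:
    # Emit at most _MAX_LINT_WARNINGS warnings, then stay silent.
    global _lint_warning_count
    if _lint_warning_count < _MAX_LINT_WARNINGS:
        _lint_warning_count += 1
        warnings.warn(msg)
```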

Test Plan:
Re-ran resnet152 with ExecuTorch and the XNNPackBackend; it is much faster now

Differential Revision: D62156090

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135069
Approved by: https://github.com/yushangdi
2024-09-06 16:41:59 +00:00
Jason Ansel
66da3b3b2a [fx] Bypass custom __setattr__ in Node.__init__ (#135079)
Before:
![image](https://github.com/user-attachments/assets/5f0a6ae6-6049-44d0-b5f2-a549a23ad97f)

After:
![image](https://github.com/user-attachments/assets/51c9f91b-f8a0-4043-8362-65813feec823)
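
The underlying Python pattern, as a general sketch (not the exact diff): when a class defines a custom `__setattr__`, hot constructors can skip its overhead by writing through `object.__setattr__` directly.

```python
class Node:
    def __setattr__(self, name, value):
        # Imagine validation/bookkeeping here; it runs on every attribute
        # write and is pure overhead for the known-good writes in __init__.
        super().__setattr__(name, value)

    def __init__(self, graph, name):
        # Bypass the custom __setattr__ for the initial fields.
        object.__setattr__(self, "graph", graph)
        object.__setattr__(self, "name", name)
```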

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135079
Approved by: https://github.com/oulgen
ghstack dependencies: #135070, #135076, #135082, #135084
2024-09-06 06:11:46 +00:00
Aaron Orenstein
ed86ac2f25 [BE] typing for decorators - fx/_compatibility (#134054)
Summary: See #131429

Test Plan: unit tests pass

Differential Revision: D61493706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134054
Approved by: https://github.com/oulgen
2024-08-26 04:00:27 +00:00
Jiashen Cao
fa8c34301a [ts-migration]: Quantized ops to standard ops pass. (#133026)
#### Description
Transform quantized operations properly by adding de/quantization before and after each quantized operation.
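
Conceptually, the pass rewrites each quantized node into a dequantize/compute/requantize sandwich. A hedged sketch with illustrative names (e.g. lowering a quantized add), not the converter's actual code:

```python
import torch

def lower_quantized_add(graph, node, scale, zero_point):
    # Replace a quantized add with: dequantize -> float add -> quantize.
    with graph.inserting_before(node):
        dq_a = graph.call_function(torch.dequantize, (node.args[0],))
        dq_b = graph.call_function(torch.dequantize, (node.args[1],))
        fp = graph.call_function(torch.add, (dq_a, dq_b))
        q = graph.call_function(
            torch.quantize_per_tensor, (fp, scale, zero_point, torch.quint8)
        )
    node.replace_all_uses_with(q)
    graph.erase_node(node)
```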

#### Test Plan
`pytest test/export/test_converter.py -s -k test_ts2ep_convert_quantized_model`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133026
Approved by: https://github.com/angelayi
2024-08-08 23:10:17 +00:00
PyTorch MergeBot
6f99e97f0a Revert "[ts-migration]: Support quantized operation transformation (#131915)"
This reverts commit 0e8541766f.

Reverted https://github.com/pytorch/pytorch/pull/131915 on behalf of https://github.com/ezyang due to test broken on windows 0e8541766f ([comment](https://github.com/pytorch/pytorch/pull/131915#issuecomment-2275974907))
2024-08-08 14:30:35 +00:00
Jiashen Cao
0e8541766f [ts-migration]: Support quantized operation transformation (#131915)
#### Description
Transform quantized operations properly by adding de/quantization before and after each quantized operation.

#### Test Plan
`pytest test/export/test_converter.py -s -k test_ts2ep_convert_quantized_model`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131915
Approved by: https://github.com/angelayi
2024-08-08 06:34:53 +00:00
David Berard
5973aec671 [fx] python_code(verbose=True): show size/strides for all tensors (#132192)
python_code(verbose=True) (or print_readable()) generates a string with the code representing the fx graph, with extra annotations indicating the size or stride of each tensor. Currently, it only shows sizes/strides for FakeTensors provided in metadata. For subclass tensors like NestedTensor, the outer class (provided in the node metadata) will be a non-FakeTensor and only the inner tensors will be fake. This PR expands the conditional to show sizes/strides for all tensors, not just FakeTensors.

Testing: I ran the test script below with `TORCH_LOGS=+dynamo` and found the graph shown in the logs excerpt - we see that the input nested tensor has sizes and strides associated with it. I also stacked a diff on top of this one that forces the readable graph to be generated whenever PT2 is in use in tests, which should surface any issues; https://github.com/pytorch/pytorch/pull/132195 shows no significant failures beyond preexisting ones.

test script:
```python
import torch

def fn(x):
    return x.cos()

nt = torch.nested.nested_tensor_from_jagged(
    torch.randn(10, 10),
    torch.tensor([0, 1, 3, 6, 10]),
)

torch.compile(fn)(nt)
```

logs excerpt:
```
[0/0] [__graph_code] TRACED GRAPH
[0/0] [__graph_code]  ===== __compiled_fn_1 =====
[0/0] [__graph_code]  /data/users/dberard/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
[0/0] [__graph_code]     def forward(self, L_x_: "f32[4, zf1, 10][10*zf1, 10, 1]cpu", zf1: "Sym(zf1)"):
[0/0] [__graph_code]         l_x_ = L_x_
[0/0] [__graph_code]
[0/0] [__graph_code]          # File: /data/users/dberard/scripts/nt_print_graph.py:4 in fn, code: return x.cos()
[0/0] [__graph_code]         cos: "f32[4, zf1, 10][10*zf1, 10, 1]cpu" = l_x_.cos();  l_x_ = None
[0/0] [__graph_code]         return (cos,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132192
Approved by: https://github.com/Chillee
2024-08-03 02:54:32 +00:00
YangQun1
589aef4bb0 Fix py codegen to delete values that don't have any users (#131028)
Fixes #131025
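
A small repro of the underlying behavior (hedged; the node names in the generated code will differ):

```python
import torch
import torch.fx as fx

def f(a, b):
    unused = a @ b  # produces a value with no users
    return a + b

gm = fx.symbolic_trace(f)
# With the fix, the generated code frees the dead value (e.g. `matmul = None`)
# right after it is produced, instead of keeping it alive.
print(gm.code)
```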

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028
Approved by: https://github.com/ezyang
2024-08-01 03:18:37 +00:00
PyTorch MergeBot
945bf78894 Revert "[BE] typing for decorators - fx/_compatibility (#131568)"
This reverts commit 193f62fde9.

Reverted https://github.com/pytorch/pytorch/pull/131568 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
Aaron Orenstein
193f62fde9 [BE] typing for decorators - fx/_compatibility (#131568)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
2024-07-25 22:24:19 +00:00
PyTorch MergeBot
c3679bed35 Revert "Fix py codegen to delete values that don't have any users (#131028)"
This reverts commit 91aba7baac.

Reverted https://github.com/pytorch/pytorch/pull/131028 on behalf of https://github.com/clee2000 due to broke inductor/test_triton_kernels inductor/test_triton_kernels.py::KernelTests::test_triton_kernel_functionalize [GH job link](https://github.com/pytorch/pytorch/actions/runs/10094659640/job/27915271250) [HUD commit link](91aba7baac) ([comment](https://github.com/pytorch/pytorch/pull/131028#issuecomment-2251058374))
2024-07-25 17:42:18 +00:00
YangQun1
91aba7baac Fix py codegen to delete values that don't have any users (#131028)
Fixes #131025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028
Approved by: https://github.com/ezyang
2024-07-25 13:04:23 +00:00
PyTorch MergeBot
8ffd109a00 Revert "Fix py codegen to delete values that don't have any users (#131028)"
This reverts commit 466c167b71.

Reverted https://github.com/pytorch/pytorch/pull/131028 on behalf of https://github.com/atalman due to breaks CI ([comment](https://github.com/pytorch/pytorch/pull/131028#issuecomment-2247771530))
2024-07-24 12:21:43 +00:00
YangQun1
466c167b71 Fix py codegen to delete values that don't have any users (#131028)
Fixes #131025

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131028
Approved by: https://github.com/ezyang
2024-07-24 01:03:56 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from the functions they decorate, so even if the underlying function is fully typed, callers get no benefit from its annotations.
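
For illustration, the standard `ParamSpec` pattern that keeps types through a decorator (a generic sketch, not the PR's exact code):

```python
from typing import Callable, ParamSpec, TypeVar  # Python 3.10+

P = ParamSpec("P")
R = TypeVar("R")

def compatibility(is_backward_compatible: bool) -> Callable[[Callable[P, R]], Callable[P, R]]:
    def decorator(fn: Callable[P, R]) -> Callable[P, R]:
        # Mark the function without wrapping it, so its signature survives.
        fn._compatibility = is_backward_compatible  # type: ignore[attr-defined]
        return fn
    return decorator
```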

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Shangdi Yu
ea4b80e6d6 [FX][export] strict DCE pass, check schema for node impurity (#130552)
Fixes the failure in `test/export/test_export_training_ir_to_run_decomp.py` caused by dead code elimination removing a node with side effects.

For background, in export, we may want to export higher-level IRs that are not functional, so we need to check for side effects more carefully.

A call_function node is impure if it has at least one mutable argument.
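
A hedged sketch of such a schema-based check (illustrative, not the exact pass code):

```python
def is_impure_call(node):
    # Ops without a schema are conservatively treated as impure.
    schema = getattr(node.target, "_schema", None)
    if schema is None:
        return True
    # Impure if any argument's alias info says the op writes to it.
    return any(
        arg.alias_info is not None and arg.alias_info.is_write
        for arg in schema.arguments
    )
```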

Fixed the tests below:

test_to_module_with_mutated_buffer_multiple_update_sub_later
test_export_input_mutation_static_shape
test_buffer_util

Another attempt at modifying the original DCE pass was made in PR #130395, but it breaks some other tests, so here we add a flag and use it for export only.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130552
Approved by: https://github.com/pianpwk
2024-07-12 15:43:27 +00:00
Chen Lai
721a798886 add bits16 to graph dtype_abbrs (#130339)
As the title says: add `bits16` to the dtype abbreviation table in `torch.fx.graph`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130339
Approved by: https://github.com/angelayi
2024-07-09 19:58:51 +00:00
angelayi
e9c6e8369c Torchbind call method + effects support (#128397)
Adds effect token support to torchbind method calls by allowing `with_effects` to take in `torch.ops._higher_order_ops.call_torchbind` as an input.

Here is the print from `TORCH_LOGS="aot" python test/export/test_torchbind.py -k test_compile_obj_torchbind_op`:
```python
def forward(self, arg0_1: "f32[0]", arg1_1: "f32[2]", arg2_1):
    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1266 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
    cos: "f32[2]" = torch.ops.aten.cos.default(arg1_1)
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, cos);  arg0_1 = cos = None
    getitem: "f32[0]" = with_effects[0];  with_effects = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1267 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
    cos_1: "f32[2]" = torch.ops.aten.cos.default(arg1_1)
    add: "f32[2]" = torch.ops.aten.add.Tensor(cos_1, 1);  cos_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, add);  getitem = add = None
    getitem_2: "f32[0]" = with_effects_1[0];  with_effects_1 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1268 in f, code: torch.ops._TorchScriptTesting.queue_pop(tq)
    with_effects_2 = torch._higher_order_ops.effects.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg2_1);  getitem_2 = None
    getitem_4: "f32[0]" = with_effects_2[0];  with_effects_2 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1269 in f, code: torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
    sin: "f32[2]" = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    with_effects_3 = torch._higher_order_ops.effects.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_push.default, arg2_1, sin);  getitem_4 = sin = None
    getitem_6: "f32[0]" = with_effects_3[0];  with_effects_3 = None

    # File: /data/users/angelayi/pytorch2/test/export/test_torchbind.py:1270 in f, code: return tq.pop(), tq.pop() + tq.size(), tq
    with_effects_4 = torch._higher_order_ops.effects.with_effects(getitem_6, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop');  getitem_6 = None
    getitem_8: "f32[0]" = with_effects_4[0]
    getitem_9: "f32[2]" = with_effects_4[1];  with_effects_4 = None
    with_effects_5 = torch._higher_order_ops.effects.with_effects(getitem_8, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'pop');  getitem_8 = None
    getitem_10: "f32[0]" = with_effects_5[0]
    getitem_11: "f32[2]" = with_effects_5[1];  with_effects_5 = None
    with_effects_6 = torch._higher_order_ops.effects.with_effects(getitem_10, torch.ops._higher_order_ops.call_torchbind, arg2_1, 'size');  getitem_10 = arg2_1 = None
    getitem_12: "f32[0]" = with_effects_6[0];  with_effects_6 = None
    add_1: "f32[2]" = torch.ops.aten.add.Tensor(getitem_11, 0);  getitem_11 = None
    return (getitem_12, getitem_9, add_1)
```

In order to support this, this PR makes the following changes:
* Adds `FakeScriptObject` to `CustomObjArgument`, which will be put on the `meta["val"]` of nodes representing torchbind objects.
* Adds pickle/deepcopy support to FunctionSchema.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128397
Approved by: https://github.com/ydwu4, https://github.com/zou3519
2024-06-14 21:28:17 +00:00
chilli
c486e2ab64 Add coloring to fx graph print out (#128476)
Note: this won't land immediately; at a minimum I'll need to add a color option. But I'm curious whether any tests fail.

Old:
<img width="1294" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/c3a750ed-5e54-4621-b2e4-be5481be15b6">

New:
<img width="1303" alt="image" src="https://github.com/pytorch/pytorch/assets/6355099/3a1f1adc-6f3a-413e-8b87-ee53da9bf4ed">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128476
Approved by: https://github.com/ezyang
2024-06-13 23:39:04 +00:00
Oguz Ulgen
5b5d269d34 Speed up fx graph iteration by implementing it in C++ (#128288)
Before this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 19.5s (5132266 nodes/s)
```

After this change
```
python benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py
iterating over 100000000 FX nodes took 3.4s (29114001 nodes/s)
```

5.7x improvement

Differential Revision: [D58343997](https://our.internmc.facebook.com/intern/diff/D58343997)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128288
Approved by: https://github.com/jansel, https://github.com/albanD
2024-06-11 05:48:31 +00:00
Aaron Orenstein
038b927590 Flip default value for mypy disallow_untyped_defs [7/11] (#127844)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843
2024-06-08 18:49:45 +00:00
Xuehai Pan
8b08b0f340 [BE] enable ruff rule Q from flake8-quotes (#127713)
Enable [ruff rule `Q`](https://docs.astral.sh/ruff/rules/#flake8-quotes-q) from flake8-quotes. Fixes:

- [avoidable-escaped-quote (Q003)](https://docs.astral.sh/ruff/rules/avoidable-escaped-quote/#avoidable-escaped-quote-q003)
- [unnecessary-escaped-quote (Q004)](https://docs.astral.sh/ruff/rules/unnecessary-escaped-quote/#unnecessary-escaped-quote-q004)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127713
Approved by: https://github.com/ezyang
2024-06-02 23:25:26 +00:00
Sheng Fu
bbeb0906c4 Register create_node_hook (#126671)
Differential Revision: D57469157

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126671
Approved by: https://github.com/angelayi
2024-05-24 23:32:15 +00:00
angelayi
8be4c1bc2f [export] Add metadata for nodes insert_deferred_runtime_asserts (#125414)
Fixes [internal error](https://fb.workplace.com/groups/1075192433118967/permalink/1416709435633930/).

The issue is that the asserting nodes added in the `insert_deferred_runtime_assertion` pass do not contain metadata that the ExportedProgram requires the graph to have. One solution is to retrace the entire module; another is to manually add back this metadata.

This diff implements the latter solution (manually adding back the metadata) by hooking into fx.graph's `create_node` function and adding export-specific metadata for every node that is created. This way, `insert_deferred_runtime_assertion` does not have to know what metadata export wants.
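
A hedged sketch of the hooking approach (names are illustrative, not the export internals):

```python
import contextlib
import torch.fx as fx

@contextlib.contextmanager
def stamp_node_metadata(graph: fx.Graph, extra_meta: dict):
    # Wrap Graph.create_node so every node created inside the context
    # gets the required metadata stamped onto node.meta.
    orig_create_node = graph.create_node

    def create_node(*args, **kwargs):
        node = orig_create_node(*args, **kwargs)
        for key, value in extra_meta.items():
            node.meta.setdefault(key, value)
        return node

    graph.create_node = create_node  # type: ignore[method-assign]
    try:
        yield
    finally:
        graph.create_node = orig_create_node  # type: ignore[method-assign]
```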

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125414
Approved by: https://github.com/zhxchen17, https://github.com/BoyuanFeng
2024-05-07 23:15:21 +00:00
Edward Z. Yang
ecd62746e3 Also pull size/stride info from example_value (#125505)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125505
Approved by: https://github.com/jansel
2024-05-05 22:27:46 +00:00
Simon Fan
43a7ab2a21 [compiled autograd] introduce verbose logs, add autograd node info to graph (#124954)
- Sets it as a fake stack trace, since we don't have a generic comment feature.
- When verbose is disabled, this still adds a context manager and flag checks; the alternative is macros, but those wouldn't be usable with TORCH_LOGS.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124954
Approved by: https://github.com/jansel
2024-04-27 01:10:37 +00:00
Sherlock Huang
c2f687f32c Option to include stride and device annotation in gm.print_readable() (#123690)
Summary:
Sample output for gm.print_readable(include_stride=True, include_device=True)

```
        getitem_21: "i32[1200][1]cuda:0" = auto_functionalized_4[1]
        copy_2: "f32[2, 60][60, 1]cuda:1"  = ....
```

Test Plan: CI

Differential Revision: D55949129

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123690
Approved by: https://github.com/Chillee
2024-04-11 06:53:10 +00:00
Oguz Ulgen
03b13851d9 [FX] Add side table to FX Graph for O(1) op/target query (#121565)
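
The idea, as a minimal hedged sketch (illustrative class, not the actual `torch.fx.Graph` internals): keep a dict keyed by `(op, target)` that is maintained on node insertion and erasure, so queries avoid a full graph scan.

```python
from collections import defaultdict

class NodeSideTable:
    """Hypothetical side table: index nodes by (op, target) for O(1) lookup."""

    def __init__(self):
        self._table = defaultdict(set)

    def insert(self, node):
        self._table[(node.op, node.target)].add(node)

    def remove(self, node):
        self._table[(node.op, node.target)].discard(node)

    def find(self, op, target):
        return self._table[(op, target)]
```
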
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121565
Approved by: https://github.com/jansel
2024-04-07 18:51:05 +00:00
Oguz Ulgen
7c5e29ae71 Back out "Support triton.language.dtype with torch.compile (#121690)" (#122108)
Summary: Some hard-to-deal-with package import/export problems. Let's revert and start with a clean slate.

Test Plan: CI

Differential Revision: D55024877

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122108
Approved by: https://github.com/ezyang
2024-03-18 20:50:28 +00:00
Oguz Ulgen
65ccac6f17 Fix triton import time cycles (#122059)
Summary: `has_triton` causes some import-time cycles. Let's use `has_triton_package`, which is sufficient.
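
A hedged sketch of the distinction (simplified; the real helpers live in `torch.utils._triton` and the device check differs):

```python
import functools
import importlib.util

@functools.lru_cache(None)
def has_triton_package() -> bool:
    # Cheap check: is triton installed at all? No torch imports needed.
    return importlib.util.find_spec("triton") is not None

@functools.lru_cache(None)
def has_triton() -> bool:
    # Heavier check: also requires a suitable GPU, which pulls in
    # torch.cuda and is where import-time cycles can come from.
    import torch
    return has_triton_package() and torch.cuda.is_available()
```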

Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//fblearner/flow/projects/model_processing/pytorch_model_export_utils/logical_transformations/tests:filter_inference_feature_metadata_test -- --exact 'fblearner/flow/projects/model_processing/pytorch_model_export_utils/logical_transformations/tests:filter_inference_feature_metadata_test - test_collect_features_from_graph_module_nodes (fblearner.flow.projects.model_processing.pytorch_model_export_utils.logical_transformations.tests.filter_inference_feature_metadata_test.FilterInferenceFromFeatureMetadataTest)'
```
now passes

Differential Revision: D55001430

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122059
Approved by: https://github.com/aakhundov
2024-03-18 05:50:32 +00:00
Oguz Ulgen
e39aedfcc5 Fix fx graph triton import bug (#122041)
Summary: Unless we register triton as a special import, the FX graph import mechanism imports it as `from fx-generated._0 import triton as triton`, which is obviously broken.
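
`torch.fx.graph` keeps a table of custom builtins that generated code imports with a dedicated import line; a hedged sketch of the fix's shape (the private helper may differ across versions):

```python
import triton  # assumes triton is installed

from torch.fx.graph import _register_custom_builtin  # private API

# Generated FX code will now emit `import triton` instead of trying to
# import it from the generated module's own package.
_register_custom_builtin("triton", "import triton", triton)
```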

Test Plan:
I could not figure out how to write a test for this, but
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//tgif/lib/tests/gpu_tests:lowering_pass_test -- -r test_default_ait_lowering_multi_hardwares
```
now passes

Differential Revision: D54990782

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122041
Approved by: https://github.com/aakhundov
2024-03-17 22:48:51 +00:00
Oguz Ulgen
79ee6bbde3 Support triton.language.dtype with torch.compile (#121690)
Marking this PR as an RFC since I have resorted to some horrible hacks to make this work.
```
(Pdb) p triton.language.float32
triton.language.fp32
(Pdb) p str(triton.language.float32)
'fp32'
(Pdb) p repr(triton.language.float32)
'triton.language.fp32'
```
This means that we need to "rewrite" them for FX graph and inductor execution.

This PR allows Mamba2 to work with `torch.compile`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121690
Approved by: https://github.com/Skylion007
2024-03-12 23:21:46 +00:00
Jason Ansel
9aa3fedb75 Slightly faster FX graph iterator (#121611)
Before:
```
iterating over 100000000 FX nodes took 5.9s (16830686 nodes/s)
```

After:
```
iterating over 100000000 FX nodes took 5.0s (19937698 nodes/s)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121611
Approved by: https://github.com/oulgen
2024-03-11 20:00:19 +00:00
Oguz Ulgen
660ec3d38d [Export] Fix bug removing node from wrong graph (#121574)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121574
Approved by: https://github.com/ydwu4
2024-03-10 04:46:11 +00:00
Jiong Gong
896cf9d1ce [inductor][cpp] vectorization support for int32/int64 (#119001)
This pull request aims to complete most of the support for vectorizing int32 and int64 data types, except for indirect indexing and masks. Basic data type support for uint32 and uint64 is also added, but without vectorization. More vectorized conversion functions are added between integer and float. To support int64 vectors, a new VectorizedN class is introduced to handle vectors of arbitrary length. Below are the details:
1. Complete most of the int32 and int64 vectorization support, including load, store, reduction, constant and conversion. Indirect indexing and masks will be addressed in follow-up PRs, after which the legality-checking logic in `CppVecKernelChecker` can be further simplified.
2. Util functions for conversion between integer and float vectors (in cpp_prefix.h and ATen vec). Ideally we should move them from cpp_prefix.h to ATen vec to simplify cpp_prefix.h; this will be addressed in follow-up PRs.
3. Introduced a new template class VectorizedN, designed to handle vectors of arbitrary length by encapsulating multiple Vectorized<T> instances. This class supports most of the operations of `Vectorized<T>` and makes int64 vectorization simpler. I will also apply it to bf16/fp16/int8 in follow-up PRs for better efficiency. For example, bf16 currently only uses half of the vector lanes; with `VectorizedN`, we can use all of the lanes and map a bf16 vector to `VectorizedN<float,2>` on conversion.
4. Basic data type support is added for uint32 and uint64 (in graph.py). Vectorization support will be added later, but it is not a high priority due to fewer usages.

Next steps:

- [ ] Refactor the vector mask handling to support data types other than float. Currently vector masks are implemented with float vectors.
- [ ] Fully utilize vector lanes for bfloat16/float16/int8.
- [ ] Support indirect indexing with vectorized index via scalarization.
- [ ] Clean up `CppVecKernelChecker`.
- [ ] Simplify `cpp_prefix.h` including refactoring vector conversion logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119001
Approved by: https://github.com/peterbell10, https://github.com/jansel
2024-02-08 17:38:49 +00:00
Jeff Daily
01abb5af21 additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)
Follow up to #107586.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115214
Approved by: https://github.com/peterbell10, https://github.com/malfet
2024-01-22 18:33:41 +00:00
PyTorch MergeBot
b637fdc8b3 Revert "additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)"
This reverts commit 74e1362499.

Reverted https://github.com/pytorch/pytorch/pull/115214 on behalf of https://github.com/PaliC due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/115214#issuecomment-1900815152))
2024-01-19 17:35:04 +00:00
Jeff Daily
74e1362499 additional support for float8_e4m3fnuz and _e5m2fnuz (#115214)
Follow up to #107586.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115214
Approved by: https://github.com/peterbell10
2024-01-19 00:50:18 +00:00
Edward Z. Yang
61a181e83c Report function name in stack trace annotations (#117459)
When working with internal flows, it can sometimes be ambiguous which
version of the code you are working with. In this case, having the
function name available in the stack trace can help identify what you
are looking at.

Example now looks like:

```
[DEBUG]         # File: /data/users/ezyang/a/pytorch/a.py:5 in f, code: return x + x
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117459
Approved by: https://github.com/Skylion007
2024-01-15 00:29:13 +00:00
Aaron Gokaslan
3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parentheses, e.g.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):`->`if x > y or y < z:`
  - And `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
Xuehai Pan
199e07f108 [pytree][BE] update treespec num_children access (#116370)
Change `len(treespec.children_specs)` to `treespec.num_children`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116370
Approved by: https://github.com/Skylion007
2023-12-24 20:54:32 +00:00
PyTorch MergeBot
684ce1b21d Revert "Assert that output could only be the last node of the FX graph (#115179)"
This reverts commit 4a9fb9832a.

Reverted https://github.com/pytorch/pytorch/pull/115179 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/115179#issuecomment-1845776365))
2023-12-07 17:26:27 +00:00
Oguz Ulgen
4a9fb9832a Assert that output could only be the last node of the FX graph (#115179)
Test Plan: unit tests

Differential Revision: D51856848

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115179
Approved by: https://github.com/Chillee
2023-12-06 08:17:16 +00:00
PyTorch MergeBot
8bb3cd192f Revert "Assert that output could only be the last node of the FX graph (#114973)"
This reverts commit a85df9eb0b.

Reverted https://github.com/pytorch/pytorch/pull/114973 on behalf of https://github.com/atalman due to Diff broke internal tests ([comment](https://github.com/pytorch/pytorch/pull/114973#issuecomment-1839290400))
2023-12-04 19:07:48 +00:00
Oguz Ulgen
a85df9eb0b Assert that output could only be the last node of the FX graph (#114973)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114973
Approved by: https://github.com/Chillee
2023-12-01 23:04:19 +00:00
PyTorch MergeBot
7c8d3639cf Revert "[fx] log the node when it's get eliminated (#112684)"
This reverts commit 6256d3710e.

Reverted https://github.com/pytorch/pytorch/pull/112684 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/112684#issuecomment-1831198778))
2023-11-29 04:31:15 +00:00
Shiyan Deng
6256d3710e [fx] log the node when it's get eliminated (#112684)
Summary: As the title says.

Test Plan: CI

Reviewed By: strisunshinewentingwang

Differential Revision: D50912413

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112684
Approved by: https://github.com/zyan0
2023-11-29 01:43:04 +00:00
Kaichao You
958f755a0e [FX][CodeGen] Make sure fx code is valid in python (#113345)
This PR fixes two cases where FX-generated code is invalid Python (syntax errors):

1. multiple type annotations on one line: `var1: annotation1, var2: annotation2 = function_call()`
2. invalid type annotations for scalars, like `var1: f32[] = function_call()`.
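
To illustrate the first case with plain Python (independent of FX):

```python
# SyntaxError: an annotated assignment allows only a single target.
#     var1: int, var2: int = divmod(7, 2)
# Valid form, without per-target annotations:
var1, var2 = divmod(7, 2)
```
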

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113345
Approved by: https://github.com/ezyang
2023-11-10 21:12:16 +00:00
Peter Bell
66c32d099a Use pytree.arg_tree_leaves everywhere (#112394)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112394
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392, #112393
2023-10-31 15:57:06 +00:00
Peter Bell
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.
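
For example, with the private pytree helpers:

```python
import torch.utils._pytree as pytree

tree = {"a": [1, 2], "b": (3, 4)}
# Before: leaves, _ = pytree.tree_flatten(tree)
leaves = pytree.tree_leaves(tree)  # [1, 2, 3, 4]
```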

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00