Commit Graph

45 Commits

Author SHA1 Message Date
Shangdi Yu
75e2a9fae3 [annotate] add annotate_fn function decorator (#165703)
Example usage:

```
        @fx_traceback.annotate_fn({"pp_stage": 1})
        def example_function(x):
            return x * x

        class SimpleLinear(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)

            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    y = self.linear(x)
                y = example_function(y)
                return y - 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703
Approved by: https://github.com/SherlockNoMad
2025-10-17 20:10:53 +00:00
PyTorch MergeBot
80d2ca7566 Revert "[annotate] add annotate_fn function decorator (#165703)"
This reverts commit f1d882212a.

Reverted https://github.com/pytorch/pytorch/pull/165703 on behalf of https://github.com/lw due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/18585518705/job/52989521797) [HUD commit link](f1d882212a) ([comment](https://github.com/pytorch/pytorch/pull/165703#issuecomment-3415073467))
2025-10-17 11:23:13 +00:00
Shangdi Yu
f1d882212a [annotate] add annotate_fn function decorator (#165703)
Example usage:

```
        @fx_traceback.annotate_fn({"pp_stage": 1})
        def example_function(x):
            return x * x

        class SimpleLinear(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)

            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    y = self.linear(x)
                y = example_function(y)
                return y - 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703
Approved by: https://github.com/SherlockNoMad
2025-10-17 07:18:47 +00:00
Shangdi Yu
5eddbb5e47 [annotate] Annotation should be mapped across submod (#165202)
The match for backward nodes might be in a different submod, so we should check all submod for potential matches.

In flex attention, this could happen if `mask_mod` has operations (such as index) that increase the seq_nr of the forward graph nodes. Then the backward flex_attention nodes cannot find a match in its own subgraph.

```
python test/functorch/test_aot_joint_with_descriptors.py -k preserve_annotate
```

Also tested on torchtitan joint_graph_runner branch. The flex_attention backward nodes are annotated now.

```
NGPU=8   CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml"   LOG_RANK=0   TRAIN_FILE="torchtitan.train"   TORCHFT_LIGHTHOUSE="http://localhost:29510"   PYTORCH_ALLOC_CONF="expandable_segments:True"   torchrun     --nproc_per_node=8     --rdzv_backend c10d     --rdzv_endpoint="localhost:0"     --local-ranks-filter 0     --role rank     --tee 3     -m torchtitan.train     --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml     --model.name joint_graph_runner.llama3     --compile.enable     --parallelism.data_parallel_shard_degree=2     --parallelism.tensor_parallel_degree=4     --model.flavor=debugmodel_flex_attn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165202
Approved by: https://github.com/SherlockNoMad
2025-10-14 16:19:38 +00:00
Animesh Jain
4308b8a28f [dynamo] Support torch.fx.traceback.annotate (#164678)
Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in the followup PRs to apply regional inductor compilation.

The existing implementation let Dynamo trace into the `torch.fx.traceback.annotate`, but thats not what we want. We want Dynamo to essentially run the torch.fx.traceback.annotate function in eager, so that every Fx node created in Dynamo Fx graph has the custom meta node.

What does not work?
* We still have to set the context manager `torch.fx.traceback.preserve_node_meta()` in the user code because CI was unhappy. This can be fixed but with some perseverance.
* This does not work with graph breaks yet. But we can solve that problem, if needed, in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
2025-10-08 22:41:00 +00:00
PyTorch MergeBot
3040a5d294 Revert "[dynamo] Support torch.fx.traceback.annotate (#164678)"
This reverts commit 801e282f39.

Reverted https://github.com/pytorch/pytorch/pull/164678 on behalf of https://github.com/izaitsevfb due to breaks executorch internally, see [D84068062](https://www.internalfb.com/diff/D84068062?entry_point=16) ([comment](https://github.com/pytorch/pytorch/pull/164678#issuecomment-3379281844))
2025-10-08 01:49:34 +00:00
Animesh Jain
801e282f39 [dynamo] Support torch.fx.traceback.annotate (#164678)
Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in the followup PRs to apply regional inductor compilation.

The existing implementation let Dynamo trace into the `torch.fx.traceback.annotate`, but thats not what we want. We want Dynamo to essentially run the torch.fx.traceback.annotate function in eager, so that every Fx node created in Dynamo Fx graph has the custom meta node.

This does not work with graph breaks yet. But we can solve that problem, if needed, in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
2025-10-07 14:54:26 +00:00
Animesh Jain
361c5d362c [fx][traceback] Actually disable preservation of node metadata when enable=False (#164772)
This will come in handy when we run graph passes that add new nodes, and
create_proxy can add seq_nr meta.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164772
Approved by: https://github.com/SherlockNoMad
2025-10-06 23:39:12 +00:00
PyTorch MergeBot
cfc5cc17dc Revert "[dynamo] Support torch.fx.traceback.annotate (#164678)"
This reverts commit 2883b5ab77.

Reverted https://github.com/pytorch/pytorch/pull/164678 on behalf of https://github.com/izaitsevfb due to fails inductor:max_autotune tests internally, see D83948169 ([comment](https://github.com/pytorch/pytorch/pull/164678#issuecomment-3374407009))
2025-10-06 22:03:42 +00:00
Animesh Jain
2883b5ab77 [dynamo] Support torch.fx.traceback.annotate (#164678)
Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in the followup PRs to apply regional inductor compilation.

The existing implementation let Dynamo trace into the `torch.fx.traceback.annotate`, but thats not what we want. We want Dynamo to essentially run the torch.fx.traceback.annotate function in eager, so that every Fx node created in Dynamo Fx graph has the custom meta node.

This does not work with graph breaks yet. But we can solve that problem, if needed, in a separate PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
2025-10-06 02:59:24 +00:00
Sherlock Huang
10e69a6e17 Preserve user annotation in graph (#163673)
```
import torch
import torch.fx.traceback as fx_traceback
import torch.export

class M(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            with fx_traceback.annotate({"fdsp_bucket": 0}):
                x = x + 1
            x = x - 2
            with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                x = x * 2
        x = x / 3
        return x

m = M()

with fx_traceback.preserve_node_meta():
    ep = torch.export.export(m, (torch.randn(10),))

for node in ep.graph.nodes:
    if node.op == "call_function":
        print(f"{node.target}, {node.meta.get("custom", {})}")

```

prints

```
aten.add.Tensor, {'pp_stage': 0, 'fdsp_bucket': 0}
aten.sub.Tensor, {'pp_stage': 0}
aten.mul.Tensor, {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1}
aten.div.Tensor, {}
```

TODOs:
- run_decomposition is failing
- Need to test with the new full graph capture + aot_export_joint apis
- Need to make the annotation propagate through autograd engine to reach the bw nodes. Sample impl here: https://github.com/pytorch/pytorch/pull/83558
- Edward want to restrict the key in custom field to be top-level singleton objects only
- also need to take care of metadata merging when passes are fusing nodes

Thanks @angelayi  for contributing the dynamo fixes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163673
Approved by: https://github.com/albanD, https://github.com/angelayi
2025-09-25 15:50:15 +00:00
Shangdi Yu
4e90441133 Add signpost to provenance tracking error (#160755)
Summary: As title, add signpost to better track error when computing provenance tracking related debugging information

Test Plan:
CI

Rollback Plan:

Differential Revision: D80292285

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160755
Approved by: https://github.com/angelayi
2025-08-18 19:17:47 +00:00
Shangdi Yu
aa99e0958f Separate provenance tracking to different levels (#160383)
Summary: as title. We've got request from various parties who are interested in turning on the provenance tracking by default. In this PR, we prepare to turn on part of the provenance tracking that doesn't have too much overhead by default.

- Change `provenance_tracking` config to `provenance_tracking_level`
- turn on the following provenance tracking by default when `basic_provenance_tracking`=True
    - `set_kernel_post_grad_provenance_tracing` for kernels, this add mapping between triton kernels and post_grad nodes
    - `dump_inductor_provenance_info` if we're dumping tlparse log
    - `get_graph_provenance_json` and dump `reate_mapping_pre_post_grad_nodes`. This creates mapping between pre_grad and post_grad nodes. Since we're not turning on the provenance tracking in GraphTransformObserver by default, the mapping here maybe incomplete/limited.
    - add stack trace from post grad nodes to inductor IR nodes
    - add exception swallowing for all functions above

Test Plan:
CI

Rollback Plan:

Differential Revision: D80031559

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160383
Approved by: https://github.com/angelayi
2025-08-15 04:59:35 +00:00
Songhao Jia
6d31d38965 recovering node source from dict (#158373) (#158473)
Summary:

this diff recovers NodeSource object from its dict representation, which is crucial for NodeSource serde.

Test Plan:
ci

Rollback Plan:

Differential Revision: D78434648

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158473
Approved by: https://github.com/angelayi
2025-07-17 17:00:19 +00:00
FFFrog
415dfabe9b [Easy] Fix the format (#158450)
When I modify the code located in test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg,
some unrelated format error occurred.

```Python
Lint for torch/_inductor/fx_passes/fuse_attention.py:

  Error (CODESPELL) spelling error
    Failed due to ValueError:
    /pytorch/pytorch/torch/_inductor/fx_passes/fuse_attention.py:587: differnt
    ==> different

    Please either fix the error or add the word(s) to the dictionary file.
    HINT: all-lowercase words in the dictionary can cover all case variations.

Lint for torch/fx/traceback.py:

  Error (MYPY) [assignment]
    Incompatible types in assignment (expression has type "str", variable has
    type "None")

        101  |
        102  |    def _get_action_string(self):
        103  |        if self._action_string is None:
        104  |            self._action_string = "+".join([a.name.lower() for a in self.action])
        105  |        return self._action_string
        106  |
        107  |    def print_readable(self, indent=0):

  Error (MYPY) [assignment]
    Incompatible types in assignment (expression has type "dict[str, Any]",
    variable has type "None")

        121  |        if self._dict is None:
        122  |            # Convert the object to a dictionary
        123  |            action_string = self._get_action_string()
        124  |            self._dict = {
        125  |                "name": self.name,
        126  |                "target": self.target,
        127  |                "graph_id": self.graph_id,

  Error (MYPY) [return-value]
    Incompatible return value type (got "None", expected "dict[Any, Any]")

        130  |                "from_node": [node.to_dict() for node in self.from_node],
        131  |            }
        132  |
        133  |        return self._dict
        134  |
        135  |    def __eq__(self, other: object):
        136  |        if not isinstance(other, NodeSource):
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158450
Approved by: https://github.com/Skylion007
2025-07-17 04:56:10 +00:00
PyTorch MergeBot
14ecc03361 Revert "recovering node source from dict (#158373)"
This reverts commit 4d055982e3.

Reverted https://github.com/pytorch/pytorch/pull/158373 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/158373#issuecomment-3080093479))
2025-07-16 19:55:21 +00:00
Songhao Jia
4d055982e3 recovering node source from dict (#158373)
Summary: this diff recovers NodeSource object from its dict representation, which is crucial for NodeSource serde.

Test Plan:
ci

Rollback Plan:

Differential Revision: D78363882

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158373
Approved by: https://github.com/yushangdi
2025-07-16 18:46:09 +00:00
David Berard
0b19d463d9 forward fix lint (#158448)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158448
Approved by: https://github.com/adamomainz
2025-07-16 14:55:33 +00:00
Songhao Jia
0cb36e2d62 cache dict and string rep for better perf (#158372)
Summary: NodeSouce should not be updated after created, so that it would be better if we cache its dict and string representation for better perf.

Test Plan:
ci

Rollback Plan:

Reviewed By: yushangdi

Differential Revision: D78298501

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158372
Approved by: https://github.com/yushangdi
2025-07-16 02:15:32 +00:00
Songhao Jia
011026205a make node source hashable (#158322)
Summary: as title

Test Plan:
ci

Rollback Plan:

Reviewed By: yushangdi

Differential Revision: D78296410

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158322
Approved by: https://github.com/yushangdi
2025-07-15 19:31:00 +00:00
Songhao Jia
1c6057fd17 add eq function to NodeSource (#158170)
Summary: add eq function to NodeSouce by comparing their dict representation.

Test Plan:
ci

Rollback Plan:

Differential Revision: D78200762

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158170
Approved by: https://github.com/ezyang, https://github.com/yushangdi
2025-07-15 00:50:06 +00:00
Shangdi Yu
a4e4368157 add node mapping processing (#146103)
Summary:
Add `node_mapping = create_node_mapping(pre_grad_graph_id, inductor_post_to_pre_grad_nodes, debug_info)`, to produce a `inductor_provenance_tracking_node_mappings.json` file. This file will be used by the provenance tracking highlighter tool to create provenance visualization.

`inductor_triton_kernel_to_post_grad_nodes.json` and `inductor_provenance_tracking_node_mappings.json` files are not dumped if they are both empty. So it's removed from some of the `test_structured_trace` tests.

Test Plan:
CI
```
buck run mode/dev-nosan  fbcode//caffe2/test:fx -- -r graph_provenance

buck run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing

python test/dynamo/test_structured_trace.py
```

Differential Revision: D68190173

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146103
Approved by: https://github.com/chenyang78
2025-02-01 08:29:29 +00:00
Aaron Orenstein
0b2a3687b9 PEP585 update - torch/fx (#145166)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145166
Approved by: https://github.com/bobrenjc93
2025-01-20 18:11:54 +00:00
Shangdi Yu
379b54603a [Inductor] [bc-breaking] Node Level provenance tracking (#144277)
Summary:

- use GraphTransformObserver + replace_node hooks to track node sources when they are replaced
- add pre_grad_graph tracking to tlparse
- add the node provenance information to post_grad_graph tlparse. This is for the frontend to create a mapping between pre_grad and post_grad graph. See an example frontend (this is just a prototype) here:  https://drive.google.com/file/d/1cMHH_0y4FJUSS9tATwGQvA72O0Lth8eh/view?usp=sharing
- change "action" of NodeSource from a single action to a list of actions.

- It's BC-Breaking because we removed `GraphTransformObserver`'s class methods `on_node_erase` and `on_node_erase` .

https://docs.google.com/document/d/1dGh9myqNhywmbfP0Quzx_f04bghDFlj8cawj8MopiO8/edit?tab=t.0

The front-end code that takes in the tlparse result is in https://github.com/yushangdi/compiler_explorer.
ghstack-source-id: 260390519

Test Plan:
```
buck2 run mode/dev-nosan fbcode//caffe2/test:fx -- -r test_graph_transform_observer
buck run mode/dev-nosan  fbcode//caffe2/test:fx -- -r node_source
buck run mode/dev-nosan  fbcode//caffe2/test:fx -- -r graph_provenance
```

Front-end example screenshots on a real model, 93% coverage rate between pre_grad_graph and post_grad_graph

 {F1973584210}{F1973584209}

```
buck2 build --show-output mode/opt -c=python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true -c fbcode.nvcc_arch=a100,h100 caffe2/torch/fb/model_transform/experimental/benchmark:mts_gpu_benchmark

MODEL_ENTITY_ID=644688112
SNAPSHOT_ID=32
MODULE=merge

TORCH_COMPILE_DEBUG=1 CUDA_VISIBLE_DEVICES=7 TORCH_LOGS="+inductor,+schedule,output_code,graph_code" TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 ../buck-out/v2/gen/fbcode/ec86b05dd59e84db/caffe2/torch/fb/model_transform/experimental/benchmark/__mts_gpu_benchmark__/mts_gpu_benchmark.par --local-model /home/bahuang/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/gpu_lowering/input.predictor.disagg.gpu.merge --lower-backend AOT_INDUCTOR_EP --gpu-trace --aot-inductor-config="{'max_autotune':
True}"

buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:auto_functionalize
```

Differential Revision: D65006709

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144277
Approved by: https://github.com/desertfire
2025-01-09 22:06:51 +00:00
Shangdi Yu
bcddae14ec Enhance "from_node" node meta to track source recursively (#142066)
Summary:
Change the "from_node" node meta format to be able to track the provenance of nodes recursively.

The new "from_node" format is a a list node NodeSource:

```
class NodeSource:
	self.node_name: str
	self.target: str
	self.graph_id: int
	self.pass_name: str
	self.action: str
	self.from_node: List[NoedSource]
```

This is in preparation for the inductor provenance tracking. For background, the inductor provenance tracking doc: https://docs.google.com/document/d/1dGh9myqNhywmbfP0Quzx_f04bghDFlj8cawj8MopiO8/edit?fbclid=IwZXh0bgNhZW0CMTEAAR0jUQ0Tf4ROLDED8Y_eIzrU0KVZVdRmyIQLp-avt-kGRPI_VgYVNyjH_q0_aem_HCQ_pxHDiwOkO9mQyWB2-g&tab=t.0 (internal only),

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_unflatten_multiple_graphs_state
buck run mode/dev-nosan caffe2/test:fx -- -r node_source
```

Differential Revision: D66737916

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142066
Approved by: https://github.com/avikchaudhuri
2024-12-09 23:39:15 +00:00
Xuehai Pan
abbd71d29d [BE][Easy] enable PYFMT for torch.fx (#138443)
Reproduce command:

```bash
ghstack checkout https://github.com/pytorch/pytorch/pull/138443
git checkout HEAD~1 torch/
lintrunner -a --take "PYFMT" --all-files
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138443
Approved by: https://github.com/ezyang
2024-10-21 19:15:49 +00:00
Aaron Orenstein
ed86ac2f25 [BE] typing for decorators - fx/_compatibility (#134054)
Summary: See #131429

Test Plan: unit tests pass

Differential Revision: D61493706

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134054
Approved by: https://github.com/oulgen
2024-08-26 04:00:27 +00:00
PyTorch MergeBot
945bf78894 Revert "[BE] typing for decorators - fx/_compatibility (#131568)"
This reverts commit 193f62fde9.

Reverted https://github.com/pytorch/pytorch/pull/131568 on behalf of https://github.com/clee2000 due to same as https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359 but I clicked the wrong link by accident.  This is where it actually starts ([comment](https://github.com/pytorch/pytorch/pull/131568#issuecomment-2254330781))
2024-07-28 03:43:39 +00:00
Aaron Orenstein
193f62fde9 [BE] typing for decorators - fx/_compatibility (#131568)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131568
Approved by: https://github.com/justinchuby, https://github.com/oulgen, https://github.com/zou3519
2024-07-25 22:24:19 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Aaron Orenstein
038b927590 Flip default value for mypy disallow_untyped_defs [7/11] (#127844)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127844
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843
2024-06-08 18:49:45 +00:00
soulitzer
13462ecd27 Update preserve_node_meta to reset torch.fx.traceback.current_meta (#125500)
Fixes https://github.com/pytorch/pytorch/issues/122766

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125500
Approved by: https://github.com/xmfan, https://github.com/ezyang
2024-05-08 14:30:34 +00:00
soulitzer
b8bd3bb30a Fix aot_autograd seq_nr logic (#118249)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118249
Approved by: https://github.com/zou3519
ghstack dependencies: #117552, #118234
2024-01-25 22:56:20 +00:00
soulitzer
3cc5c42a23 Fix aot sequence_nr to reset bwd flag (#107210)
The way the aot autograd sequence_nr tracking works is that we run the aot export logic, the dynamo captured forward graph is run under an fx.Interpreter, which iterates through the nodes of the forward graph while setting the `current_metadata`.
Since during backward what is run doesn't correspond to any node during forward, we fallback to the global `current_metadata`. And since this global metadata is ends up being shared between runs, that leads to weirdness if we forget to reset things, e.g., depending whether this is the first test run, the printed results will be different.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107210
Approved by: https://github.com/bdhirsh
2023-08-24 16:58:12 +00:00
Alex Settle
9ba0558d48 Add sequence_nr to aot_autograd to map forward ops to their corresponding backward ops (#103129)
Fixes #102375

Sequence_nr increments in the forward pass and decrements in the backward pass.  Backward ops with the same sequence_nr as a forward op represent the backward implementation for the op.  The long term goal is to make this information available to the profiler so users can observe which ops are fused by the inductor openai triton kernels.

Added a test for this feature **test/dynamo/test_aot_autograd.py::AotAutogradFallbackTests::test_aot_sequence_nr**.  The test case uses **aot_export_module()** to create a joint fwd/bwd fx graph.  Then it walks all the nodes in fx graph using fx_graph.graph.nodes.   The seq_nr of each node is recorded in node.meta.  During the fwd pass the seq_nr increments and it decrements during the bwd pass.  This allows the user to map forward ops to their corresponding bwd ops which is useful for performance analysis.

Expected output from the test case.

 SeqNr|OrigAten|SrcFn
0|aten.convolution.default|l__self___conv1
0|aten.add.Tensor|l__self___bn1
1|aten._native_batch_norm_legit_functional.default|l__self___bn1
2|aten.relu.default|l__self___relu1
3|aten.add.Tensor|add
4|aten.view.default|flatten
5|aten.t.default|l__self___fc1
6|aten.unsqueeze.default|l__self___fc1
7|aten.mm.default|l__self___fc1
8|aten.squeeze.dim|l__self___fc1
9|aten.add.Tensor|l__self___fc1
10|aten.sub.Tensor|l__self___loss_fn
11|aten.abs.default|l__self___loss_fn
12|aten.mean.default|l__self___loss_fn
12|aten.ones_like.default|
12|aten.expand.default|
12|aten.div.Scalar|
11|aten.sgn.default|
11|aten.mul.Tensor|
8|aten.unsqueeze.default|
7|aten.t.default|
7|aten.mm.default|
7|aten.t.default|
7|aten.t.default|
7|aten.mm.default|
6|aten.squeeze.dim|
5|aten.t.default|
4|aten.view.default|
2|aten.threshold_backward.default|
1|aten.native_batch_norm_backward.default|
0|aten.convolution_backward.default|
0|aten.add.Tensor|

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103129
Approved by: https://github.com/soulitzer
2023-08-02 00:52:52 +00:00
Sherlock Huang
a770295af4 Don't alter original node's meta in Interpreter (#105880)
Test Plan: OSS CI

Reviewed By: angelayi

Differential Revision: D47740058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105880
Approved by: https://github.com/angelayi
2023-07-26 03:44:58 +00:00
Sherlock Huang
36fe31f537 [Reland] Refactor stack_trace preservation for node meta preservation (#90803) (#92400)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90803
Approved by: https://github.com/jerryzh168, https://github.com/albanD
ghstack-source-id: 5848cca08ef5d6f8868f4f79d8bc29711e9a52c2

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92400
Approved by: https://github.com/jerryzh168
2023-01-30 23:30:43 +00:00
PyTorch MergeBot
498be7ed25 Revert "Refactor stack_trace preservation for node meta preservation (#90803)"
This reverts commit 0f1302eeae.

Reverted https://github.com/pytorch/pytorch/pull/90803 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-01-10 10:44:28 +00:00
Sherlock Huang
0f1302eeae Refactor stack_trace preservation for node meta preservation (#90803)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90803
Approved by: https://github.com/jerryzh168, https://github.com/albanD
2023-01-09 23:23:27 +00:00
Sherlock Huang
eeba9d5ab4 Preserve node's meta during fx.transformation (#90737)
We wish to preserve node.meta over fx.Transformer transformation and aot_autograd. This will preserve all the meta fields in the original node, including stack_trace, nn_module_stack, val, tensor_meta...

Sample

Here's a graph produced by Dynamo.
```
class GraphModule(torch.nn.Module):
    def forward(self, x : torch.Tensor, y : torch.Tensor):
        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:35, code: a = torch.cos(x)
        cos = torch.cos(x);  x = None

        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:36, code: b = torch.sin(y)
        sin = torch.sin(y);  y = None

        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:37, code: return a + b
        add = cos + sin;  cos = sin = None
        return (add,)

x {'creation_timestamp': 0, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 45, in forward\n    def forward(self, x, y):\n'}
y {'creation_timestamp': 0, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 45, in forward\n    def forward(self, x, y):\n'}
cos {'creation_timestamp': 3, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 35, in forward\n    a = torch.cos(x)\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n'}
sin {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 36, in forward\n    b = torch.sin(y)\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n'}
add {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 37, in forward\n    return a + b\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n'}
output {'creation_timestamp': 4}
```

After lowering to aten graph with aot_autograd_simplified()
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: f32[2, 3], primals_2: f32[2, 3]):
        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:35, code: a = torch.cos(x)
        cos: f32[2, 3] = torch.ops.aten.cos.default(primals_1)

        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:36, code: b = torch.sin(y)
        sin: f32[2, 3] = torch.ops.aten.sin.default(primals_2)

        # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:37, code: return a + b
        add: f32[2, 3] = torch.ops.aten.add.Tensor(cos, sin);  cos = sin = None
        return [add, primals_2, primals_1]

primals_1 {'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=True, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
primals_2 {'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=True, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
cos {'creation_timestamp': 3, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 35, in forward\n    a = torch.cos(x)\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n', 'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
sin {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 36, in forward\n    b = torch.sin(y)\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n', 'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
add {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': '  File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 37, in forward\n    return a + b\n |   File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n    return self.block(x, y)\n', 'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
output {}
```

Notice that output fx node have creation_time_stamp, nn_module_stack and stack_trace copied from the original fx node.
val and tensor_meta were latter populated by a subsequent fake_tensor_propagation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90737
Approved by: https://github.com/jerryzh168
2023-01-06 17:21:02 +00:00
Sherlock Huang
caf3d5319f Symintify numel(), infer_size, prims.elementwise_meta (#88956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88956
Approved by: https://github.com/ezyang
2022-11-20 00:42:03 +00:00
PyTorch MergeBot
8ad39536d7 Revert "Symintify numel(), infer_size, prims.elementwise_meta (#88956)"
This reverts commit ce2f8700ba.

Reverted https://github.com/pytorch/pytorch/pull/88956 on behalf of https://github.com/ezyang due to somehow breaks torch.numel
2022-11-19 21:47:55 +00:00
Sherlock Huang
ce2f8700ba Symintify numel(), infer_size, prims.elementwise_meta (#88956)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88956
Approved by: https://github.com/ezyang
2022-11-16 03:36:00 +00:00
Sherlock Huang
a7baad04f6 Preserve stack trace for backward nodes over AOTAutograd (#83558)
For the following program.
```
def my_relu(a):
    return a.relu()

def func(a, b):
    a = torch.nn.Linear(10, 10)(a)
    d = torch.square(b)
    d = my_relu(d)
    loss = d.sum()

    return loss

with torchdynamo.optimize("aot_nop"):
    x = torch.rand(10, 10, requires_grad=True)
    y = torch.rand(10, 10, requires_grad=True)
    out = func(x, y)
```

It would generate the following fx graph with stack_trace populated in both forward and backward nodes.
```
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    t_default = torch.ops.aten.t.default(primals_3);  primals_3 = None
    addmm_default = torch.ops.aten.addmm.default(primals_4, primals_1, t_default);  primals_4 = primals_1 = t_default = None
    pow_tensor_scalar = torch.ops.aten.pow.Tensor_Scalar(primals_2, 2)
    relu_default = torch.ops.aten.relu.default(pow_tensor_scalar);  pow_tensor_scalar = None
    detach_default = torch.ops.aten.detach.default(relu_default)
    sum_default = torch.ops.aten.sum.default(relu_default);  relu_default = None
    is_same_size_default = torch.ops.aten.is_same_size.default(sum_default, tangents_1)
    expand_default = torch.ops.aten.expand.default(tangents_1, [10, 10]);  tangents_1 = None
    detach_default_1 = torch.ops.aten.detach.default(detach_default);  detach_default = None
    threshold_backward_default = torch.ops.aten.threshold_backward.default(expand_default, detach_default_1, 0);  expand_default = detach_default_1 = None
    pow_tensor_scalar_1 = torch.ops.aten.pow.Tensor_Scalar(primals_2, 1.0);  primals_2 = None
    mul_scalar = torch.ops.aten.mul.Scalar(pow_tensor_scalar_1, 2.0);  pow_tensor_scalar_1 = None
    mul_tensor = torch.ops.aten.mul.Tensor(threshold_backward_default, mul_scalar);  threshold_backward_default = mul_scalar = None
    return pytree.tree_unflatten([sum_default, None, mul_tensor, None, None], self._out_spec)

====== joint graph =======
primals_1 None
primals_2 None
primals_3 None
primals_4 None
tangents_1 None
t_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

addmm_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 12, in func
    def func(a, b):
  File "/fsx/users/bahuang/repos/pytorch_fsx/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)

pow_tensor_scalar   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

relu_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

detach_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

sum_default
is_same_size_default
expand_default
detach_default_1   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

threshold_backward_default   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 15, in func
    d = my_relu(d)
  File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 10, in my_relu
    return a.relu()

pow_tensor_scalar_1   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_scalar   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

mul_tensor   File "/fsx/users/bahuang/repos/pytorch_fsx/test.py", line 14, in func
    d = torch.square(b)

output None
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83558
Approved by: https://github.com/albanD
2022-08-18 22:13:04 +00:00
Sherlock Huang
6915676448 Preserve node's stack trace during retrace (#83050)
AOTAutograd retraces graph module produced by torch dynamo, this PR preserves the stack trace in the original fx.Node.

Differential Revision: [D38595638](https://our.internmc.facebook.com/intern/diff/D38595638)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83050
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
2022-08-11 04:18:14 +00:00