```
import torch
import torch.fx.traceback as fx_traceback
import torch.export
class M(torch.nn.Module):
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
with fx_traceback.annotate({"fdsp_bucket": 0}):
x = x + 1
x = x - 2
with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
x = x * 2
x = x / 3
return x
m = M()
with fx_traceback.preserve_node_meta():
ep = torch.export.export(m, (torch.randn(10),))
for node in ep.graph.nodes:
if node.op == "call_function":
print(f"{node.target}, {node.meta.get("custom", {})}")
```
prints
```
aten.add.Tensor, {'pp_stage': 0, 'fdsp_bucket': 0}
aten.sub.Tensor, {'pp_stage': 0}
aten.mul.Tensor, {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1}
aten.div.Tensor, {}
```
TODOs:
- run_decomposition is failing
- Need to test with the new full graph capture + aot_export_joint apis
- Need to make the annotation propagate through autograd engine to reach the bw nodes. Sample impl here: https://github.com/pytorch/pytorch/pull/83558
- Edward want to restrict the key in custom field to be top-level singleton objects only
- also need to take care of metadata merging when passes are fusing nodes
Thanks @angelayi for contributing the dynamo fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163673
Approved by: https://github.com/albanD, https://github.com/angelayi
Landing this instead of https://github.com/pytorch/pytorch/pull/162994.
Here is how i think the whole dynamo + frame construction logic work:
1) There is no way to create a frame object in python land as this is created in runtime from cpython. So that's why aot_compile creates FrameInfo this way. (kind of like simulating the runtime) i guess you could write your own very simple eval_frame.c where you can interject the frame construction but we probably don't want that.
2) When there is no wrapper (the old export or aot_compile), we first assign sources by iterating over f_locals which contain both local args and closure variables (this is implementation details of cpython frame construction). So thats why closure variables end up getting LocalSource names as can be shown in this test case (f6ea41ead2/test/export/test_export.py (L1369)). Note that L["self"] here means we are referring to local object self. Important thing to keep in mind here is this self is not actually model self, but the outer self.
3) When we switch to wrapper case, we end up trying to inline the original inner module. When doing so, we need to track all local and closures for this inner module as can be seen here (f6ea41ead2/torch/_dynamo/variables/functions.py (L463)) Here we are not looking into inner frame's f_locals but just directly look at closures. I guess this is because we are one more frame up so there is no access to frame f_locals at this point. And it is probably not good idea to change dynamo's logic here. As a result, i get following error message that is different from old export:
"While exporting, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: ["L['self']._export_root.forward.__func__.__closure__[1].cell_contents.bank", "L['self']._export_root.forward.__func__.__closure__[1].cell_contents.bank_dict", "L['self']._export_root.forward.__func__.__closure__[0].cell_contents"]"
My initial attempt of solving this was taking inner closures and put them to f_locals for the frame i am constructing which turned out too compilcated because we needed to muck around bytecode instructions as well. So i am thinking we should just update the test to reflect new names and follow up with better post-processing step to have better names.
Differential Revision: [D82582029](https://our.internmc.facebook.com/intern/diff/D82582029)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163107
Approved by: https://github.com/avikchaudhuri
Fixes#160547
### Summary:
bug
```
def test_namedtuple(self):
from collections import namedtuple
Point = namedtuple('Point', 'x y')
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
inp = Point(torch.ones(3), torch.ones(3))
print(M()(*inp))
# errors
ep = torch.export.export(M(), inp, strict=False)
print(ep)
# succeeds
ep = torch.export.export(M(), inp, strict=True)
print(ep)
# workaround could be to convert namedtuple to a kwarg
inp_kwargs = {field: getattr(inp, field) for field in inp._fields}
ep = torch.export.export(M(), (), inp_kwargs)
print(ep)
```
FIx :
namedtuple is subclass of tuple
but namedtuple is not expected
So, this change handles named tuple case
I have added 🧪 test case for this as well
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162959
Approved by: https://github.com/angelayi
Co-authored-by: Angela Yi <angelayi@meta.com>
Currently OutputGraphGuardsState is separated out as a serializable interface for OutputGraph, but some of the typing around it is incorrect in dynamo's guards.py and output_graph.py: more fields are used by code than claimed by OutputGraphGuardsState, and it works because either the full OutputGraph is passed in or the parts that use those fields are dead when OutputGraphGuardsState is passed in.
In this PR we try to further separate the necessary fields of OutputGraph that should be retained by a full graph capture mechanism, not just limited to dynamo (as it is currently) but also something like make_fx (in the future). Since these fields do not need to be serialized, the result is an intermediate "common" data structure that is between OutputGraphGuardsState and OutputGraph in the inheritance hierarchy.
Differential Revision: D81718791
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162211
Approved by: https://github.com/zhxchen17
These decompositions take precedence before CIA decomps in fake tensor prop, as a result, we would hit this implementation for all where overloads which is wrong in some cases. For the overloads that can't be implemented by this decomp, we just run the default CIA impl. Previously this doesn't matter because in post-dispatch IR, aten.where would have decomposed but when user tries to preserve aten.where this issue will surface because fake tensor will start seeing aten.where.
Differential Revision: [D82604702](https://our.internmc.facebook.com/intern/diff/D82604702)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163138
Approved by: https://github.com/henryoier, https://github.com/ezyang
Note this could change suggest_memory_format behaviour for unbacked
we used to return True for are_strides_like_channels_last sometimes even when results undecided
now when its not decided we return False.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162354
Approved by: https://github.com/aorenste
Investigated together with @pyemma and @taotaohuang001
## Problem
when calling exported module with dict nested in the args tuple, it will make following complaits
```
Traceback (most recent call last):
File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
raise e
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
raise ValueError( # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*,
*])]),
TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*,
*])]),
TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.
```
## How to reproduce the issue
```python
import torch
# create a nn.Module with data_batch as input and output as output
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = torch.nn.Linear(10, 1)
def forward(self, data_batch):
h1 = self.linear(data_batch["a1"])
h2 = self.linear(data_batch["a2"])
return h1 + h2
# torch export this module
model = MyModel()
example_args_forward = (
{
"a1": torch.randn(10),
"a2": torch.randn(10),
},
)
exported_model = torch.export.export(model, example_args_forward, strict=True)
# save the exported model
torch.export.save(exported_model, "exported_model.pt2")
# load the exported model
exported_model = torch.export.load("exported_model.pt2").module()
# run the exported model
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
```
## Root Cause
Input spec is encoded as [TreeSpec](582d278983/torch/utils/_pytree.py (L1059)) in torch export. With (args, kwargs) at the top level. When we call the exported model, it has a pre-execution [hook](582d278983/torch/export/_unlift.py (L66)) to check the input TreeSpec matches the received TreeSpec, where in Treespec, the dict key order is preserved. Something like
TreeSpec(dict, ['a2', 'a1'], [*,*])
To workaround this, the input check reorders [kwargs](582d278983/torch/export/_unlift.py (L67)), that is why kwargs can be out of order. But the dict nested in the args is not re-ordered, so any re-ordering of the keys will throw errors.
## Solution
Update eq_spec to handle the dict case, where we only guarantee that key set is the same without ordering constraints.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162618
Approved by: https://github.com/angelayi
This makes gemma3 exportable on transformers=4.55.4
In HF, there is a torch funciton mode called TransformGetItemToIndex which internally calls custom autograd function. When this custom autograd function is called under vmap, It triggers CustomFunctionHigherOrderOP which error-ed because there was no pre-dispatch proxy mode implementation.
Since there are number of requests lately to add various operators in pre-dispatch IR, I introduce a decorator in export that works similar to `allow_in_graph`. Basically:
1) We intercept custom_autograd_function.apply at pre-dispatch mode when this decorator is applied
2) We apply `flat_apply` HOP to hide the pytree spec for this autograd function. Note that this adds restriction that this custom autograd function needs to take in fx-able types.
3) subclass constructor decorator is implemented similarly, so we just refactor it to use similar implementation as this new decorator. eventually we should delete the subclass constructor decorator.
4) Move some code in subclass constructor decorator to exit early in non-export environment which should shave off some inefficiency (around 1% according to @swolchok 's benchmark)
Fixes: https://github.com/pytorch/pytorch/issues/161563#issuecomment-3246309758
Differential Revision: [D82141316](https://our.internmc.facebook.com/intern/diff/D82141316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162240
Approved by: https://github.com/ydwu4
Summary:
Sometimes `ShapeEnv.create_symbol` can return a `sympy.Integer`. This messes up our phantom symbol infra for derived dims.
Fixes#161902
Test Plan:
added test based on repro
Rollback Plan:
Differential Revision: D81960709
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162416
Approved by: https://github.com/tugsbayasgalan
This PR is quite large in that it covers most of rough edges in the new strict export flow:
1. Handle nn_module_stack correctly now that we are tracing wrapper module
2. module_call_spec needs to get queried from source directly because we are not running the bytecode anymore.
3. Correct input and output handling.
@diff-train-skip-merge
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
This PR is quite large in that it covers most of rough edges in the new strict export flow:
1. Handle nn_module_stack correctly now that we are tracing wrapper module
2. module_call_spec needs to get queried from source directly because we are not running the bytecode anymore.
3. Correct input and output handling.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162183
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162167
Summary: This PR introduces shape guards to export. Previously only value ranges, equalities, and specializations would be tracked for symbolic expressions, and we had a forward hook to check them. Instead now we create a function to check shape guards and call it in the exported program.
Test Plan:
updated several tests
Rollback Plan:
Differential Revision: D80713603
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161178
Approved by: https://github.com/tugsbayasgalan
Fixes#161080
torch.export.export fails with TypeError: expand() got an unexpected keyword argument 'implicit' when calling torch.expand_copy(..., implicit=True). This happened because expand_copy = _make_copy_from_view(aten.expand) register aten. expand as the decomposition path for aten.expand_copy, which doesn’t accept the implicit argument.
I have added an explicit a decomposition for aten.expand_copy in torch/_decomp/decompositions.py to ignore the implicit argument, and a simple unit test to demonstrate the bug being fixed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161688
Approved by: https://github.com/angelayi, https://github.com/can-gaa-hou
Summary:
[reland]
Since `allow_complex_guards_as_runtime_asserts` is now sync'd with `prefer_deferred_runtime_asserts_over_guards`, we can kill the former (especially since it was a export-only concept).
Test Plan:
updated tests
Rollback Plan:
Differential Revision: D81334984
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161794
Approved by: https://github.com/zhxchen17
Summary:
Our strategy for detecting fake tensor leakage in non-strict for outside scope (side effects happening outside of model.forward) is:
1. We do gc.collect() before export and get the alive fake tensors
2. We dump the proxy to fake tensor map from make_fx tracer
3. We query gc again to get alive fake tensors
4. We take the delta between (1) and (3)
5. Filter out fake tensors that are:
1. Associated with `TrackedFake` (input tracking thing in symbolic_shapes)
2. Associated with `gm.meta`
6. Do ID match with the proxies and emit their stacktraces.
We rely on (https://github.com/pytorch/pytorch/pull/159923) for other sources of leakages such as:
1. We failed to proxy an operator (like param.data)
2. We cache some tensor in model.forward (https://github.com/pytorch/pytorch/issues/155114)
In general, we notice `gc.collect()` and query-ing gc for live objects are kinda slow. So we turn on this feature under env variable. We should document on export public facing documents that if you run into weird errors regarding fake tensors, they should look into turning on this env variable for further analysis.
Test Plan:
Test plan
Rollback Plan:
Differential Revision: D80003204
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160456
Approved by: https://github.com/pianpwk
Summary: Since `allow_complex_guards_as_runtime_asserts` is now sync'd with `prefer_deferred_runtime_asserts_over_guards`, we can kill the former (especially since it was a export-only concept).
Test Plan:
updated tests
Rollback Plan:
Differential Revision: D79903317
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160198
Approved by: https://github.com/ezyang