This makes gemma3 exportable on transformers=4.55.4
In HF, there is a torch funciton mode called TransformGetItemToIndex which internally calls custom autograd function. When this custom autograd function is called under vmap, It triggers CustomFunctionHigherOrderOP which error-ed because there was no pre-dispatch proxy mode implementation.
Since there are number of requests lately to add various operators in pre-dispatch IR, I introduce a decorator in export that works similar to `allow_in_graph`. Basically:
1) We intercept custom_autograd_function.apply at pre-dispatch mode when this decorator is applied
2) We apply `flat_apply` HOP to hide the pytree spec for this autograd function. Note that this adds restriction that this custom autograd function needs to take in fx-able types.
3) subclass constructor decorator is implemented similarly, so we just refactor it to use similar implementation as this new decorator. eventually we should delete the subclass constructor decorator.
4) Move some code in subclass constructor decorator to exit early in non-export environment which should shave off some inefficiency (around 1% according to @swolchok 's benchmark)
Fixes: https://github.com/pytorch/pytorch/issues/161563#issuecomment-3246309758
Differential Revision: [D82141316](https://our.internmc.facebook.com/intern/diff/D82141316)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162240
Approved by: https://github.com/ydwu4
Summary:
A demo for creating AOTI delegate for NativeRT in OSS.
- It supports full graph lowering only.
- It leverages `executorch_call_delegate` HOP but doesn't rely on `executorch`.
- The delegate graph is obtained by tracing a `LoweredBackendModule` whose forward function calls `executorch_call_delegate`.
- The main difference between `executorch_call_delegate` and `aoti_call_delegate` is that the delegate graph from `executorch_call_delegate` doesn't have weights lifted as inputs.
- original_ep and delegate_ep are treated as flat EP dictionary and there is no nested structure.
- The naming contract is enforced by `model_name` and `backend_id`
Test Plan:
CI
Rollback Plan:
Differential Revision: D81641157
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162285
Approved by: https://github.com/dolpm
Summary: ONNX team and recent transformer upgrade ran into this error and we also ran into during our export benchmarking. This diff makes it possible to trace through vmap implementation in pre-dispatch IR. Note that we don't support serializing functorch ops in pre-dispatch IR and in the future, we should desugar them to post-grad ops.
The implementation strategy is:
1. We add python wrappers around vmap APIs so that we attach custom torch function handler that is only on during non-strict export. The reason is we don't want to add this to default torch_function handler because it will break BC.
2. Some dynamo changes to make sure it picks up new python wrapper APIs. The reason is when we do strict export, we need to re-materialize these APIs in pre-dispatch IR from torch IR. We can avoid this by special casing in dynamo for export to proxy different API calls but i feel that is too much chaos because you need to be able to proxy 2 different variants of same vmap API.
Test Plan: CI
Differential Revision: D75623875
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154650
Approved by: https://github.com/ezyang, https://github.com/zou3519
Summary:
In HF model rwkv, we have parameter mutation under inference mode which should be safe. This PR does multiple things to make sure it works:
1. We execute global autograd mutation while tracing so that we can actually trace through parameter inplace mutation
2. Add support for parameter mutation under inference mode in AOTAutograd
3. Add support for parameter mutation under inference mode in export.
Test Plan:
test
Rollback Plan:
Differential Revision: D79460136
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159661
Approved by: https://github.com/ydwu4
lint:
- test/test_fake_tensor.py
- test/test_flop_counter.py
- torch/_export/verifier.py
with same rules as other files, it was a night mare for me to update tests in one of the skipped files
with not being able to lint them locally like other files with lintrunner -a.
note that those file do have active dev and not old not touched files.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154261
Approved by: https://github.com/angelayi, https://github.com/Skylion007
Summary:
att
regular weight has the type of torch.nn.parameter.Parameter
buffer and tensor constant has the type of torch.Tensor
both types are valid.
Test Plan: CI
Differential Revision: D72657275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150867
Approved by: https://github.com/zhxchen17
Summary:
Previously, aoti compile node is represented as a kernel-less custom op in the exported program. The node was not eager runnable, which is a common practice for numerical validation during lowering.
I introduce a new HOP to address this.
The schema is following
```
aoti_call_delegate(lower_moduel: AOTInductorEPModule, original_gm: fx.GraphModule, weights: List[Tensor], inputs: List[Tensor])
```
There are a few problems exposed by HOP
- AOTI expects a FX graph with weights as getattr nodes, aka stateful graph. HOP expect graph_module arguments to be stateless. Export serializer also expect a stateless graph. Currently, to make AOTI happy, I am making `original_gm` stateful, and bypassing the serialization for `original_gm`.
- As a result, the HOP is not re-traceable, as functionalization on stateful graph module argument will fail.
Test Plan: buck2 test 'fbcode//mode/opt' fbcode//deeplearning/aot_inductor/cpu/test:cpu_lowering_utils_test
Reviewed By: zhxchen17
Differential Revision: D68359391
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145630
Approved by: https://github.com/zou3519
# Why?
I want the following code to work.
minimal repro:
```
class M(torch.nn.Module):
def forward(self, dilate_flag):
return dilate_flag.item()
input1 = (torch.tensor([1], dtype=torch.bool, device="cuda"),)
model = M().cuda()
ep = torch.export.export(model, input1, strict=True)
path = torch._inductor.aot_compile(ep.module(), input1)
aot_model = torch._export.aot_load(path, device="cuda")
actual_output = aot_model(*input1)
```
error: AssertionError: Encountered an unsupported object of type <class 'torch.SymBool'> while writing the metadata for exported program
second error will be handled by https://github.com/pytorch/pytorch/pull/138760
# Motivation
I could technically bypass it with a torch.int tensor. However, it doesn't work with torch.cond. I want the following to work. It would also require https://github.com/pytorch/pytorch/pull/138760 for aot compile to work.
```
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.dilate_flag = 0
def forward(self, dilate_flag):
self.dilate_flag = dilate_flag.item()
def true_fn(dilate_flag):
return dilate_flag.clone()
def false_fn(dilate_flag):
return dilate_flag.clone()
torch.cond(
self.dilate_flag,
true_fn,
false_fn,
(dilate_flag,),
)
return self.dilate_flag
input1 = (torch.tensor([1], dtype=torch.bool, device="cuda"),)
input2 = (torch.tensor([0], dtype=torch.bool, device="cuda"),)
inputs = (input1, input2)
model = M().cuda()
for input in inputs:
expected_output = model(*input)
ep = torch.export.export(model, input, strict=False)
path = torch._inductor.aot_compile(ep.module(), input)
aot_model = torch._export.aot_load(path, device="cuda")
actual_output = aot_model(*input)
assert (
expected_output == actual_output
), f"henry they are not equal {expected_output} != {actual_output}"
```
Differential Revision: D64867504
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138765
Approved by: https://github.com/ydwu4
In this diff, i make test_torchbind.py tests to handle training IR. Today in the training IR, we don't see the effect token and HOP because this happens at the FunctionalTensorMode. Maybe in the future, we should move this logic up to the training IR so that writing passes etc on training Ir is safer. But for the migration purposes, i think it is ok for now. I also fixed two bugs:
1. ep.module() doesn't register all aliased constants in the module.
2. When we retrace, we need to fakify the original Torchbind object.
3. We don't run any DCE on training IR so we need to add some more torch ops to verifier.
Differential Revision: [D64853530](https://our.internmc.facebook.com/intern/diff/D64853530)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138658
Approved by: https://github.com/ydwu4, https://github.com/zhxchen17
Summary: When we are placing nodes in the graph, we should also replace the references in module_call_graph.
Test Plan:
buck2 run 'fbcode//mode/opt' torchrec/fb/ir/tests:test_serializer -- --filter-regex test_serialize_deserialize_vlea
buck2 test 'fbcode//mode/opt' fbcode//torchrec/fb/ir/tests:test_serializer -- --exact 'torchrec/fb/ir/tests:test_serializer - torchrec.fb.ir.tests.test_serializer.TestSerializer: test_serialize_empty_value_vlea' --run-disabled
buck2 test 'fbcode//mode/opt' fbcode//torchrec/fb/ir/tests:test_serializer -- --exact 'torchrec/fb/ir/tests:test_serializer - torchrec.fb.ir.tests.test_serializer.TestSerializer: test_deserialized_device_vle' --run-disabled
Differential Revision: D62014035
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134830
Approved by: https://github.com/angelayi
Some sympy Functions aren't supported by sympy_interp(); we can't turn them into FX nodes, so currently the runtime asserts CSE pass avoids CSE'ing on any expression containing a sympy Function. https://github.com/pytorch/pytorch/pull/132325 started tracking unsupported functions, so we switch the check to that to be more precise. We also check for and skip unsupported functions when adding asserts - previously we only did the check for CSE, and not adding new expressions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132457
Approved by: https://github.com/avikchaudhuri
Summary: Finishing up the mechanism to "register" certain types of operators to a registry so that the serializer can handle them correctly. This is expected to be firstly used by executorch.
Test Plan: buck run mode/opt caffe2/test:test_export -- -r test_export_with_extension_op_serialization
Differential Revision: D59825148
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130851
Approved by: https://github.com/angelayi
Summary: This diff updates the ExportedProgram class in PyTorch to allow for multiple verifiers to be attached to it. This is done by adding a new field to the ExportedProgram schema called "verifiers" which is a list of strings representing the names of the verifiers to be attached to the program. The verifiers are loaded using the "load_verifier" function which is defined in the "torch._export.serde.serialize" module. The "exported_program.dialect" field is also deprecated in favor of the "verifiers" field.
Test Plan: CI
Differential Revision: D59408546
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130364
Approved by: https://github.com/angelayi, https://github.com/ydwu4
At a high level, the idea behind this PR is:
* Make it clearer what the promotion and int/float rules for various Sympy operations are. Operators that previously were polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc in sympy, instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper MUST be floats, and so forth for integers.
The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:
* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the divide out addition by gcd optimization, because sympy gcd is over fields and is willing to generate rationals (but rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift, RShift now assert they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out Sympy support for floating operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver that is used by the offline solver
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows for us to eventually generate accurate code for Python semantics IntTrueDiv, which is written in a special way to preserve precision when the inputs are >= 2**53 beyond what first coercing the integer to floats and then doing true division.
* Trunc is split to TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal updated to consistently only ever return a float
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing)
In **torch/__init__.py**, we modify SymInt and SymFloat to appropriately call into new bindings that route to these refined sympy operations. Also, we modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the return result is always a float), making them inconsistent with builtins.min/max, but possible to do type analysis without runtime information.
We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:
* `to_int` for truncation to int64, directly corresponding to TruncToInt; this can be implemented by trunc and dtype, but with a dedicated handler it is more convenient for roundtripping in Sympy
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`
These changes have consequences. First, we need to make some administrative changes:
* Actually wire up these Sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2)
* Add support for new Sympy functions in **torch/utils/_sympy/interp.py**, **torch/utils/_sympy/reference.py**
* In particular, in torch.utils._sympy.reference, we have a strong preference to NOT do nontrivial compute, instead, everything in ops handler should map to a singular sympy function
* TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency this to fix tests here
* Add printer support for the Sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py**, **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed precision equality is currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet
* Update ValueRanges logic to use new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new Sympy function rather than try to roll things by hand, which is what was done previously for many VR analysis functions.
In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:
* Avoid generation of rational subexpressions by removing simplification of `x // y` into `floor(x / y)`. This simplification then triggers an addition simplification rule `(x + y) / c --> x / c + y / c` which is bad because x / c is a rational number now
* `_assert_bound_is_rational` is no more, we no longer generate rational bounds
* Don't intersect non-int value ranges with the `int_range`
* Support more sympy Functions for guard SYMPY_INTERP
* Assert the type of value range is consistent with the variable type
The new asserts uncovered necessary bug fixes:
* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - Ensure Wild/Symbol manually allocated in Inductor is marked `is_integer` so it's accepted to build expressions
* **torch/_inductor/utils.py** - make sure you actually pass in sympy.Expr to these functions
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges, instead use sys.maxsize - 1
Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**
**Reland notes.** This requires this internal fbcode diff https://www.internalfb.com/phabricator/paste/view/P1403322587 but I cannot prepare the diff codev due to https://fb.workplace.com/groups/osssupport/posts/26343544518600814/
It also requires this Executorch PR https://github.com/pytorch/executorch/pull/3911 but the ET PR can be landed prior to this landing.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
This PR adds a registration function and a global registry for GraphModuleSerializer. After this PR, custom serialization methods can be done through registration instead of subclassing for ease of maintenance.
## Changes
- Add a test case where it injects custom op to test serialization.
- Add custom op handler
- Change allowed op for verifier
Co-authored-by: Zhengxu Chen <zhxchen17@outlook.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126550
Approved by: https://github.com/zhxchen17
This PR adds a new metadata, `torch_fn` which is meant to replace `source_fn_stack` as `source_fn_stack` is not entirely well defined between strict/nonstrict. Previous discussion [here](https://docs.google.com/document/d/1sPmmsmh6rZFWH03QBOe49MaXrQkP8SxoG8AOMb-pFk4/edit#heading=h.anmx9qknhvm).
`torch_fn` represents the torch function that a particular aten operator came from. For example, `torch.nn.Linear` goes down to the `torch.nn.functional.linear` at the `__torch_function__` layer, and then `aten.t/aten.addmm` in the `__torch_dispatch__` layer. So the nodes `aten.t/aten.addmm` will now have the `torch_fn` metadata containing the `torch.nn.functional.linear`.
The `torch_fn` metadata is a tuple of 2 strings: a unique identifier for each torch function call, and the actual torch function `f"{fn.__class__}.{fn.__name__}"`. The purpose of the first value is to distinguish between 2 consecutive calls to the same function. For example, if we had 2 calls to `torch.nn.Linear`, the nodes and corresponding metadata would look something like:
```
aten.t - ("linear_1", "builtin_function_or_method.linear"),
aten.addmm - ("linear_1", "builtin_function_or_method.linear"),
aten.t - ("linear_2", "builtin_function_or_method.linear"),
aten.addmm - ("linear_2", "builtin_function_or_method.linear"),
```
Higher order ops -- currently we can get the torch_fn metadata for nodes within the HOO's subgraph, but after retracing, this becomes the `(cond, higher_order_op.cond)` :( This is because `fx_traceback.set_current_meta` points to the cond node in the toplevel graph, rather than the original node in the subgraph. I think this is because `fx.Interpreter` does not go into the cond subgraphs. (will discuss with Yidi more ab this)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122693
Approved by: https://github.com/tugsbayasgalan
We would like to improve consistency for nn_module_stack metadata in torch.export.
This PR ensures that all tests in test/export/test_export.py has the following constraints:
- Remove nn_module_stack for all placeholder & output nodes, for all modules and submodules
- Ensure nn_module_stack is present for all other node types for the top-level module (there is still an issue with torch.cond submodules having empty fields)
- Add these checks to _export() in _trace.py (we would add this in the Verifier, but downstream apps construct ExportedPrograms separate from _export(), and metadata may not be maintained there)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120661
Approved by: https://github.com/avikchaudhuri
Summary:
X-link: https://github.com/pytorch/executorch/pull/1817
Basic support for non-persistent buffers, which are buffers that do not show up in the state dict.
One weird twist is that most of our other systems (FX, aot_export, dynamo) have completely buggy handling of non-persistent buffers. I tried to go on a wild goose chase to fix them all, but it got to be too much. So I introduced some sad rewrite passes in `_export` make the final state dict correctly align with the original module's state dict.
This exposed some bugs/ambiguous handling of parameters/buffers in existing test code. For example, `TestSaveLoad.test_save_buffer` traced over a module that was not in the root module hierarchy and caused some weird behavior. I think we should error explicitly on use cases like this: https://github.com/pytorch/pytorch/issues/118410. For now I just rewrote the tests or skipped them.
As a side effect, this diff tightened up quite a few sloppy behaviors around state dict handling:
- Tensor attributes were getting promoted to be buffers—bad!
- Tracing through a module not in the children of the root module would add its parameters/buffers to the state dict—bad!
This behavior is unlikely to show up in user code since the model would be totally broken, but did show up in a bunch of tests.
#buildmore
Test Plan:
unit tests
sandcastle
Differential Revision: D53340041
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118969
Approved by: https://github.com/guangy10, https://github.com/huydhn, https://github.com/titaiwangms
Summary:
X-link: https://github.com/pytorch/executorch/pull/1769
Basic support for non-persistent buffers, which are buffers that do not show up in the state dict.
One weird twist is that most of our other systems (FX, aot_export, dynamo) have completely buggy handling of non-persistent buffers. I tried to go on a wild goose chase to fix them all, but it got to be too much. So I introduced some sad rewrite passes in `_export` make the final state dict correctly align with the original module's state dict.
This exposed some bugs/ambiguous handling of parameters/buffers in existing test code. For example, `TestSaveLoad.test_save_buffer` traced over a module that was not in the root module hierarchy and caused some weird behavior. I think we should error explicitly on use cases like this: https://github.com/pytorch/pytorch/issues/118410. For now I just rewrote the tests or skipped them.
Test Plan: added a unit test
Differential Revision: D53253905
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118722
Approved by: https://github.com/SherlockNoMad, https://github.com/angelayi
Summary:
We used to skip verifier when the signature object is not the "correct" one (usually from some deprecated frontend). This was very useful when we wanted to pay a small cost to enable verifier path to be called everywhere for torch export.
Now I believe no tests are relying on this behavior so we should remove this weird branch.
Test Plan: CI
Differential Revision: D53024506
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118139
Approved by: https://github.com/suo
Added support for constant outputs. We will just embed the constant directly into the output, like `return (x, 1)`.
Also adds support for None input/outputs. For None inputs we address it the same way we do to constants, which is that a placeholder with no users will be inserted into the graph, and the None will be embedded into whatever operator is using the None. For None outputs, we will also address the same way we do constants, which is that we embed it into the output, like `return (x, None)`.
Differential Revision: D52881070
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117894
Approved by: https://github.com/zhxchen17
Summary:
We intend to preserve autograd ops for predispatch export. Therefore, we
need to exempt the autograd ops in some places, e.g. verifier and
proxy_tensor.py.
Test Plan:
python test/export/test_export.py -k test_predispatch_export_with_autograd_op
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116527
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #116339