Summary: WrapperModule seems like a good idea, but it may introduce surprising behavior for users. For example, it never registers the enclosed modules as submodules, so it's unclear what the state dict for the exported program should look like: some people may argue that every piece of state belongs in the state dict, while others want to keep them as constants.
Test Plan: CI
Reviewed By: tugsbayasgalan
Differential Revision: D54326331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121042
Approved by: https://github.com/angelayi
This PR makes the tests for inline and sequential_split stop relying on set_grad_enabled calls being present in the graph, because those calls will be gone once we turn on the replace_set_grad_with_hop_pass in the following diff. Instead, we manually insert them into the graph.
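For reference, manually inserting a set_grad_enabled call into an FX graph looks roughly like the sketch below (standard torch.fx APIs; the toy module is only illustrative, not the actual test code):
```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = fx.symbolic_trace(M())
# Insert a call to torch._C._set_grad_enabled(False) right before the first call_function node.
first_call = next(n for n in gm.graph.nodes if n.op == "call_function")
with gm.graph.inserting_before(first_call):
    gm.graph.call_function(torch._C._set_grad_enabled, (False,))
gm.recompile()
print(gm.code)  # the graph now contains the manually inserted _set_grad_enabled call
```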
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119914
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #119732, #119736, #119810, #119913
As titled. Before this PR, after we split and then inline_, there would be getitem calls in the graph that the original graph module doesn't have. This PR removes the additional getitem calls during inlining.
Test Plan:
Added new test cases for graphs that return multiple outputs and take multiple inputs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119913
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #119732, #119736, #119810
This PR is 1/N in a series transforming global-state-mutating ops, such as torch._C._set_grad_enabled calls, in the pre-dispatch graph into a higher-order op so that the graph becomes more functional. We make use of split_module to help us do the transformation.
This PR preserves node.name from the original module by adding a new kwarg `keep_original_node_name` to split_module.
For a graph that looks like this:
```python
def forward(self, arg_0):
    arg0_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
    sub = torch.ops.aten.sub.Tensor(add_1, 1)
    return pytree.tree_unflatten((add_1, sub), self._out_spec)
```
Before the change, split_module returns the following graph and subgraphs (notice the change from `add` -> `add_tensor` and `sin` -> `sin_default`):
```python
def forward(self, arg_0):
    arg0_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
    submod_0 = self.submod_0(arg0_1); arg0_1 = None
    submod_1 = self.submod_1(submod_0); submod_0 = None
    submod_2 = self.submod_2(submod_1)
    return pytree.tree_unflatten((submod_1, submod_2), self._out_spec)

# submod_0
def forward(self, arg0_1):
    add_tensor = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
    sin_default = torch.ops.aten.sin.default(add_tensor); add_tensor = None
    sum_default = torch.ops.aten.sum.default(sin_default); sin_default = None
    return sum_default

# submod_1
def forward(self, sum_1):
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    add_tensor = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    return add_tensor

# submod_2
def forward(self, add_1):
    _set_grad_enabled = torch._C._set_grad_enabled(True)
    sub_tensor = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    return sub_tensor
```
After the change, the test produces the following graph; all the node names in the original graph module are preserved in the submodules:
```python
def forward(self, arg_0):
    sub, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
    submod_0 = self.submod_0(sub); sub = None
    submod_1 = self.submod_1(submod_0); submod_0 = None
    submod_2 = self.submod_2(submod_1)
    return pytree.tree_unflatten((submod_1, submod_2), self._out_spec)

# submod_0
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    return sum_1

# submod_1
def forward(self, sum_1):
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    return add_1

# submod_2
def forward(self, add_1):
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
    sub = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    return sub
```
Note that currently we call split_module on the graph after pre-dispatch AOT export. The difference would be even larger if we ran `split_module` on the graph module produced by dynamo, where all the original variable names from the user program are preserved after dynamo but lost after `split_module` without this change.
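A hedged usage sketch of `split_module` with the new kwarg (the trivial split callback below, which puts every node into one partition, is purely illustrative):
```python
import torch
import torch.fx as fx
from torch.fx.passes.split_module import split_module

class M(torch.nn.Module):
    def forward(self, x):
        return (x + 1).sin().sum()

gm = fx.symbolic_trace(M())
# keep_original_node_name=True keeps the traced node names in the submodules.
split = split_module(gm, M(), split_callback=lambda node: 0, keep_original_node_name=True)
print(split.submod_0.code)  # node names from the original traced graph are preserved
```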
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119732
Approved by: https://github.com/tugsbayasgalan
Summary:
X-link: https://github.com/pytorch/executorch/pull/1817
Basic support for non-persistent buffers, which are buffers that do not show up in the state dict.
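For context, a non-persistent buffer is registered with the standard nn.Module API and participates in computation, but is absent from state_dict():
```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("persistent_buf", torch.zeros(3))
        self.register_buffer("nonpersistent_buf", torch.zeros(3), persistent=False)

    def forward(self, x):
        return x + self.persistent_buf + self.nonpersistent_buf

m = M()
print(m.state_dict().keys())  # contains 'persistent_buf' but not 'nonpersistent_buf'
```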
One weird twist is that most of our other systems (FX, aot_export, dynamo) have completely buggy handling of non-persistent buffers. I tried to go on a wild goose chase to fix them all, but it got to be too much. So I introduced some sad rewrite passes in `_export` to make the final state dict correctly align with the original module's state dict.
This exposed some bugs/ambiguous handling of parameters/buffers in existing test code. For example, `TestSaveLoad.test_save_buffer` traced over a module that was not in the root module hierarchy and caused some weird behavior. I think we should error explicitly on use cases like this: https://github.com/pytorch/pytorch/issues/118410. For now I just rewrote the tests or skipped them.
As a side effect, this diff tightened up quite a few sloppy behaviors around state dict handling:
- Tensor attributes were getting promoted to be buffers—bad!
- Tracing through a module not in the children of the root module would add its parameters/buffers to the state dict—bad!
This behavior is unlikely to show up in user code since the model would be totally broken, but did show up in a bunch of tests.
#buildmore
Test Plan:
unit tests
sandcastle
Differential Revision: D53340041
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118969
Approved by: https://github.com/guangy10, https://github.com/huydhn, https://github.com/titaiwangms
Summary:
X-link: https://github.com/pytorch/executorch/pull/1769
Basic support for non-persistent buffers, which are buffers that do not show up in the state dict.
One weird twist is that most of our other systems (FX, aot_export, dynamo) have completely buggy handling of non-persistent buffers. I tried to go on a wild goose chase to fix them all, but it got to be too much. So I introduced some sad rewrite passes in `_export` to make the final state dict correctly align with the original module's state dict.
This exposed some bugs/ambiguous handling of parameters/buffers in existing test code. For example, `TestSaveLoad.test_save_buffer` traced over a module that was not in the root module hierarchy and caused some weird behavior. I think we should error explicitly on use cases like this: https://github.com/pytorch/pytorch/issues/118410. For now I just rewrote the tests or skipped them.
Test Plan: added a unit test
Differential Revision: D53253905
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118722
Approved by: https://github.com/SherlockNoMad, https://github.com/angelayi
This PR rewrites two paths to use the newly-added keypaths API in pytree:
First: we were hand-rolling a tree_map during fakification because we wanted to track sources. This PR uses keypaths instead, which can do the same thing without needing custom code.
Second: our constraint error formatting was referencing placeholder names in error messages. These placeholder names are not otherwise user-visible, so they are super confusing to users (e.g. "which input does arg1_3 correspond to?"). This diff uses the `keystr` API to format the error message.
This necessitated some small refactors—generating the keystr is expensive so doing it in an f-string was very bad.
It can also be further improved—we can inspect the signature so that instead of `*args[0]` we can give people the actual argument name, which would be the ideal UX. But leaving that for later.
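A rough sketch of the keypath-based formatting described above, assuming the keypath helpers in `torch.utils._pytree` (`tree_flatten_with_path`, `keystr`); this is not the actual export code:
```python
import torch
from torch.utils._pytree import tree_flatten_with_path, keystr

example_inputs = (torch.randn(2), {"scale": torch.randn(3)})
flat_with_paths, _spec = tree_flatten_with_path(example_inputs)
for keypath, leaf in flat_with_paths:
    # keystr(keypath) yields something like "[0]" or "[1]['scale']",
    # which is far more readable than an internal placeholder name like arg1_3.
    print(f"input {keystr(keypath)} has shape {tuple(leaf.shape)}")
```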
Differential Revision: [D53139358](https://our.internmc.facebook.com/intern/diff/D53139358/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118609
Approved by: https://github.com/zhxchen17
ghstack dependencies: #118607, #118608
Summary:
Currently we have a very ugly specialization on the edge dialect in the verifier, like the following:
```
# TODO Remove this branch.
if ep.dialect == "EDGE":  # !!! Don't change this allowlist. !!!
    pass
else:
    raise e
```
In this diff we do some additional work to make signature checking also work in exir. We decouple the transformation stacks in torch export and exir so that different layers of the stack can evolve in their own fashion and the team can divide and conquer them separately.
Test Plan: CI
Differential Revision: D52499225
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116705
Approved by: https://github.com/tugsbayasgalan
Previously we were generating a graph to add runtime assertions on inputs and then running that graph to check input constraints. This PR checks input constraints directly.
Differential Revision: D50289970
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111262
Approved by: https://github.com/zhxchen17
Recently we updated the `export` API to take an experimental `dynamic_shapes` argument that was meant to subsume the existing `constraints` argument.
This PR deprecates `constraints` (with a warning on its use, but without actually removing it). Simultaneously it replaces all uses of `constraints` in docs, examples, and tests with corresponding uses of `dynamic_shapes` (preserving behavior). This exercise fortunately revealed some minor bugs in the implementation which have also been fixed in this PR.
Some uses of `constraints` still remain, e.g., when `torch._dynamo.export` is called directly. (Meta-internal uses will be updated in a separate diff.)
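A hedged example of the `dynamic_shapes` argument that replaces `constraints` (the model and dim names are illustrative):
```python
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()

# Mark dimension 0 of input "x" as dynamic instead of passing `constraints`.
batch = Dim("batch")
ep = export(M(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}})
print(ep.graph_signature)
```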
Differential Revision: D49676049
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110143
Approved by: https://github.com/tugsbayasgalan
Update graph_signature according to the graph after transformation. Transformations can lead to node name changes, and those names are used in graph_signature to identify inputs and outputs. Therefore, after each transformation, we need to update the graph_signature according to the new node names.
WARNING: This implementation makes a few assumptions:
- The transformation doesn't change the number of inputs/outputs.
- Each input/output still has the same meaning:
  - For inputs, that means the inputs in the transformed graph map to the same lifted parameter/buffer or user input as the input at the same position in the graph before transformation.
  - Similarly for outputs, each output should correspond to the same mutated buffer or user output as the output value at the same position in the graph before transformation.
It is difficult to programmatically validate these assumptions, but they should hold true most of the time, as the inputs/outputs of the graph rarely need to be changed.
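A hypothetical sketch of the positional re-mapping this implies (the `input_specs`/`arg.name` attributes below are assumptions made for illustration, not necessarily the exact export data structures of this change):
```python
def update_signature_names(old_gm, new_gm, graph_signature):
    """Rename signature entries by matching placeholders positionally."""
    old_names = [n.name for n in old_gm.graph.nodes if n.op == "placeholder"]
    new_names = [n.name for n in new_gm.graph.nodes if n.op == "placeholder"]
    # Assumption 1: the transformation kept the number of inputs unchanged.
    assert len(old_names) == len(new_names)
    rename = dict(zip(old_names, new_names))
    # Assumption 2: each input keeps its meaning, so a positional rename suffices.
    for spec in graph_signature.input_specs:
        if spec.arg.name in rename:
            spec.arg.name = rename[spec.arg.name]
    return graph_signature
```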
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107080
Approved by: https://github.com/tugsbayasgalan
Some notable changes:
1. `constrain_as_size` now allows the min value to be less than 2, as the compiler will unconditionally assume min >= 2 for its purposes. Instead, we add an additional check to make sure the max value is always greater than 2.
2. Previously, we used to runtime-assert on the unbacked symint's value range, which would always be [2, max]. I modified this logic to assert on [0, max] unless the user explicitly specifies the min range.
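To illustrate the [0, max] semantics, here is a present-day stand-in that uses `torch._check` / `torch._check_is_size` rather than the era-specific `constrain_as_size` call:
```python
import torch

def f(x):
    n = x.item()
    torch._check_is_size(n)  # size-like: n >= 0 (the compiler may still reason with n >= 2)
    torch._check(n <= 10)    # explicit upper bound, so the runtime check covers roughly [0, 10]
    return torch.zeros(n)

print(f(torch.tensor([5])).shape)  # torch.Size([5])
```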
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
This PR:
* Address comment at https://github.com/pytorch/pytorch/pull/103887/files#r1244128266.
* Add a test for graph partition to make sure assertion-op functionalization won't break graph partition in unexpected ways.
**NOTE**:
In the context of export, it's totally up to the user to do any type of graph partition based on their specific use case. It's hard to anticipate the concrete downstream use case or provide any specific functionality to facilitate handling assertion ops (functional / non-functional). So this PR limits itself to [`CapabilityBasedPartitioner`](2da6cae43c/torch/fx/passes/infra/partitioner.py (L34)) and makes sure it doesn't break graph partition unexpectedly (by adding some tests).
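A hedged sketch of exercising `CapabilityBasedPartitioner` on a relu-split graph similar in spirit to the test (the operator-support policy below is illustrative, not the test's actual one):
```python
import torch
import torch.fx as fx
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

class SupportCallFunctions(OperatorSupport):
    def is_node_supported(self, submodules, node):
        # Only fuse call_function nodes; the relu call_method stays outside partitions.
        return node.op == "call_function"

class M(torch.nn.Module):
    def forward(self, a, b):
        x = a + b
        return (x + 1).relu() + 2

gm = fx.symbolic_trace(M())
partitioner = CapabilityBasedPartitioner(gm, SupportCallFunctions(), allows_single_node_partition=True)
fused = partitioner.fuse_partitions(partitioner.propose_partitions())
print(fused.graph)  # contains fused_* submodules split around the relu
```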
For the test case used in this PR, a few things to highlight:
* Without assertion, the fused graph is roughly like:
```
class fused(torch.nn.Module):
    def forward(self, a, b):
        fused_1 = self.fused_1(a, b);
        relu = fused_1.relu()
        fused_0 = self.fused_0(fused_1, relu)
        return (fused_0, fused_1)

class fused_0(torch.nn.Module):
    def forward(self, add_2, relu):
        ...  # Logic after relu
        return add_4

class fused_1(torch.nn.Module):
    def forward(self, a, b):
        ...  # Logic before relu, `add_1` is only exposed within this submodule.
        return add_2
```
* With the assertion, the fused graph is roughly like:
```
class fused(torch.nn.Module):
    def forward(self, arg0_1: i64[s0], arg1_1: i64[s0]):
        dep_token0 = ...
        ...
        fused_1 = self.fused_1(arg0_1, arg1_1); arg0_1 = arg1_1 = None
        ...
        getitem: i64[s0] = fused_1[0]  # `getitem` is actually `add_1`
        ...
        relu_default: i64[s0] = torch.ops.aten.relu.default(getitem_1)
        ...
        # For the inline assertion. Note that `getitem`, which is an output of `fused_1`, is consumed by it.
        select_int: i64[] = torch.ops.aten.select.int(getitem, 0, 0)
        eq_scalar: b8[] = torch.ops.aten.eq.Scalar(select_int, 5)
        dep_token2: f32[] = torch.ops.aten._functional_assert_async.msg(
            eq_scalar, 'assertion error', dep_token = dep_token1
        )
        ...
        getitem_1: i64[s0] = fused_1[1]  # `getitem_1` is actually `add_2`
        fused_0: i64[s0] = self.fused_0(getitem_1, relu_default)
        ...
        return (fused_0, getitem_1, dep_token2)

class fused_0(torch.nn.Module):
    def forward(self, add_tensor_2: i64[s0], relu_default: i64[s0]):
        ...  # Logic after relu
        return add_tensor_4

class fused_1(torch.nn.Module):
    def forward(self, arg0_1: i64[s0], arg1_1: i64[s0]):
        ...  # Logic before relu
        # `add_tensor_1` (basically `add_1`) is returned so the downstream assertion op can consume it.
        return (add_tensor_1, add_tensor_2)
```
As shown above, the extra assertion (regardless of whether it's functionalized or not) **won't** cause extra submodule breakage if the asserted node is an intermediate node within a submodule: the intermediate node is returned as an extra output of the submodule so that the downstream assertion node can consume it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104287
Approved by: https://github.com/tugsbayasgalan
This PR integrated the assertion functionalization logic into current export logic.
**NOTE:**
I finally decided to do the assertion functionalization after AOT export instead of before for the following reasons:
* The benefit of doing it after AOT export is that the graph is already functionalized, so things like method calls have already been transformed into function calls. However, if we do it before AOT export, the graph is still at the torch level and extra logic like bab21d20eb/torch/_export/pass_base.py (L201-L204C17) will need to be implemented.
* The graph signature is currently already somewhat incorrect after adding runtime assertions (this doesn't seem to break logic since we already depend on positions instead of FQNs of outputs). This PR also fixes this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103887
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
The idea here is to do a graph mutation to:
* Create an initial dependency token at the beginning of the program.
* Replace non-functional version of assertion statements to functional version.
* The functional version of assertion statement will:
* Accept a dependency token from output of previous functional assertion statement (or the initial dependency token if there isn't any).
* Generate a dependency token as the output of assertion statement.
* Augment the output to include the dependency token generated by last assertion statement.
The goal here is to:
* Form an explicit dependency chain and avoid potential reordering during other passes of compiling.
* Make the assertions a part of the overall execution graph that affects the final output (otherwise they could potentially be DCEed).
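An illustrative-only sketch of what the chaining looks like when built by hand with torch.fx (the op name follows the example graphs elsewhere in this log; the graph below is a toy, not the pass's actual output):
```python
import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
dep_token0 = g.call_function(torch.empty, ((0,),))  # initial dependency token
cond = g.call_function(torch.ops.aten.ge.Scalar, (x, 0))
# The functional assertion consumes the previous token and produces a new one.
dep_token1 = g.call_function(
    torch.ops.aten._functional_assert_async.msg,
    (cond, "x must be non-negative"),
    {"dep_token": dep_token0},
)
out = g.call_function(torch.ops.aten.add.Tensor, (x, 1))
# The last dep token joins the outputs so the assertion cannot be dead-code-eliminated.
g.output((out, dep_token1))
print(fx.GraphModule(torch.nn.Module(), g).code)
```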
**NOTE:**
* Currently this only covers `constrain_range`; support for other assertions is WIP. Sending out this PR to collect feedback first.
* This PR only focuses on the implementation itself. It will be integrated with the current export in a future PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103757
Approved by: https://github.com/avikchaudhuri
At a high level, the current implementation of the constraint functions (constrain_as_**) will raise an exception for the following code snippet:
```
def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```
The reason is the following:
1) The current constraint logic is purely Python, so it won't survive AOT export (the node is gone after AOT export since AOT export only maintains aten-level ops).
2) It relies on a side effect to add range constraints to the traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))).
3) If runtime assertions are turned on (the default), [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) will try to append assertion nodes based on the range constraints extracted from the symbols' shape env during another interpretation round.
4) However, because of 1), the range constraint logic won't run for symbols generated during the AOT export round, so there is no range constraint information available later for the assertion round, which causes the issue.
5) As a result of the above, it fails at `torch.empty((a, 4))` (there is no constraint saying `a` must be positive).
The fix here is to implement the range constraint logic as a native aten op (with a no-op CPU implementation) so that it can survive AOT export.
**NOTE:**
[Logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to handle the case where a non-`SymInt` is passed in, and it is reused in the new `_constrain_range`. The reason is that when a non-`SymInt` is provided:
* If it directly calls `sym_constrain_range`, the C++ version will be called, which is a no-op.
* So in this case it calls `constrain_range_int` instead, to be able to catch issues like a user providing an input whose tensor shape is out of range during export, as in the following variation of the code example above:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,)) # immediately raise error
```
Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
This PR adds aot_export_module as the lowering path from the torch-level graph to an aten graph. Some known limitations that need to be addressed in follow-up PRs:
1. Store param/buffer data in ExportedProgram
2. Fully support torch.cond with params/buffers
3. Making sure no duplicated ExportMetaData entry
4. This API will break Executorch if used on PyE; we will figure out a plan internally.
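A hedged usage sketch of aot_export_module (an internal torch._functorch API, so the exact signature may differ between versions):
```python
import torch
from torch._functorch.aot_autograd import aot_export_module

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("b", torch.ones(3))

    def forward(self, x):
        return (x + self.b).sin()

# Lowers the torch-level module to an aten-level graph plus a graph signature.
gm, graph_signature = aot_export_module(M(), (torch.randn(3),), trace_joint=False)
print(gm.graph)          # aten ops, with the buffer lifted to a graph input
print(graph_signature)   # records how lifted inputs map back to params/buffers
```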
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101490
Approved by: https://github.com/avikchaudhuri
Previously we had runtime asserts for range constraints. This diff adds runtime asserts for equality constraints.
This requires a bit of refactoring that is worth calling out.
1. [Minor] Some of the data structures produced by export and consumed by the runtime assertion pass need to be broadened. This is a WIP. There are some associated code improvements that are included in this diff, but by and large the structures are similar to what exists now. Meanwhile @angelayi and I are chatting about how to make it qualitatively better: briefly, we want to index everything by symbols, which are 1-1 with (name, dim) pairs.
2. [Major] The order in which runtime asserts are emitted is changed. Previously we used to do the work in `placeholder`, now this diff adds a hook for "post-processing" after processing of all placeholders is done. This is needed because equality constraints can mention different placeholders. This change also opens the way to optimizing codegen: e.g., each (name, dim) pair should correspond to a single intermediate variable that is reused across runtime asserts. This is future work.
Differential Revision: [D46177642](https://our.internmc.facebook.com/intern/diff/D46177642/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102256
Approved by: https://github.com/tugsbayasgalan, https://github.com/angelayi
This PR does the following:
1. Previously, inline constraints were not properly set for data-dependent ops with tensor outputs, such as a.nonzero(), because their return value is not a SymInt. This PR just uses all the unbacked symbols, i.e. those that start with "i"/"f" in the create_unbacked_sym* functions. Note that these symbols are guaranteed to be a superset of the inline user constraints.
2. Adds support for inline assertions by checking the ranges of these symbols at runtime.
Currently, it only deals with data-dependent ops whose outputs are tensors, SymInts, SymFloats, or SymBools and ignores the rest. That's good enough for now, as we only have a limited number of data-dependent ops (.item and .nonzero are explicitly tested).
An example of a graph with the added assertions is shown below:
```
class ExportGraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        nonzero_default: i64[i0, 1] = torch.ops.aten.nonzero.default(arg0); arg0 = None
        return pytree.tree_unflatten([nonzero_default], self._out_spec)

class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        sym_size: Sym(s0) = torch.ops.aten.sym_size(arg0, 0)
        nonzero_default: i64[i1, 1] = torch.ops.aten.nonzero.default(arg0); arg0 = None
        sym_size_1: Sym(i1) = torch.ops.aten.sym_size(nonzero_default, 0)
        ge: Sym(i1 >= 3) = sym_size_1 >= 3
        scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(ge); ge = None
        _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].'); scalar_tensor_default = None
        le: Sym(i1 <= 5) = sym_size_1 <= 5; sym_size_1 = None
        scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(le); le = None
        _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].'); scalar_tensor_default_1 = None
        return pytree.tree_unflatten([nonzero_default], self._out_spec)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100763
Approved by: https://github.com/tugsbayasgalan
I ported over the code for the inline interpreter incorrectly in the pass base 😅
Originally the function `make_inline_interpreter` was supposed to take an fx.Interpreter type, but I accidentally passed in an fx.Interpreter object. I also realized while modifying this diff (and from Tugsuu's comments) that we don't really need this InlineInterpreter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100836
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
* Added ExportPassBase, an interpreter-based helper class for writing passes (see the stand-in sketch after this list).
* It can also help maintain the dialect based on the operator namespace by having users override the `get_valid_dialects` function (returning an empty list implies the pass works for any dialect).
* Added a `ReplaceBrokenOpsWithFunctionalOpsPass` to replace all ops that have not been converted by functionalization with their functional equivalents.
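To illustrate the interpreter-based pass-writing pattern, here is a stand-in that uses torch.fx.Transformer rather than ExportPassBase itself (the op substitution is just an example):
```python
import torch
import torch.fx as fx

class ReplaceAddWithSub(fx.Transformer):
    def call_function(self, target, args, kwargs):
        # Swap every torch.add call for torch.sub while re-tracing the module.
        if target is torch.add:
            return super().call_function(torch.sub, args, kwargs)
        return super().call_function(target, args, kwargs)

class M(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, 1)

gm = fx.symbolic_trace(M())
new_gm = ReplaceAddWithSub(gm).transform()
print(new_gm.code)  # the add has been rewritten to a sub
```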
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100000
Approved by: https://github.com/gmagogsfm