Commit Graph

31 Commits

Author SHA1 Message Date
Zhengxu Chen
138fafe72d [export] Fix torch.export() issues for server use cases. (#108275)
Test Plan: In D48788843

Differential Revision: D48811793

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108275
Approved by: https://github.com/tugsbayasgalan
2023-08-31 07:19:18 +00:00
gmagogsfm
9af0e47653 Hide transform method by renaming it (#107940)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107940
Approved by: https://github.com/tugsbayasgalan
2023-08-25 16:31:44 +00:00
Angela Yi
92f6454ff8 [export][reland] ExportedProgram.transform updates graph_signature automatically (#107792)
Summary: Reland of https://github.com/pytorch/pytorch/pull/107080

Test Plan: CI

Differential Revision: D48533622

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107792
Approved by: https://github.com/gmagogsfm
2023-08-23 22:16:56 +00:00
eellison
c88775b937 Make Nd tensors hit fused addmm pass (#106911)
Replaces https://github.com/pytorch/pytorch/pull/106433, since I had a bad CLA commit.

Speeds up eager convnext bfloat16 inference by 35%, and eager timm bfloat16 inference on average by `.5%`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106911
Approved by: https://github.com/ezyang
2023-08-16 17:12:11 +00:00
PyTorch MergeBot
b860c8c5b8 Revert "ExportedProgram.transform now updates graph_signature automatically (#107080)"
This reverts commit 8c9b2fe8f0.

Reverted https://github.com/pytorch/pytorch/pull/107080 on behalf of https://github.com/izaitsevfb due to Breaks executorch tests, see D48333170 ([comment](https://github.com/pytorch/pytorch/pull/107080#issuecomment-1679588292))
2023-08-15 20:47:35 +00:00
Tugsbayasgalan Manlaibaatar
20c5add133 [export] Refactor constrain_as_value and constrain_as_size (#106591)
Some notable changes:
1. `constrain_as_size` now allows the min value to be less than 2, since the compiler unconditionally assumes min >= 2 for compilation purposes anyway. Instead, we add an additional check to make sure the max value is always greater than 2.
2. Previously, we used to runtime-assert on the unbacked symint's value range, which would always be [2, max]. I modified this logic to assert on [0, max] unless the user explicitly specifies the min range.
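A minimal sketch of the new behavior (the `torch._export.constraints` import path and keyword form are assumed from this era's API; values are hypothetical): with no explicit min, the runtime assert now checks [0, max] while the compiler still assumes min >= 2 internally.
```python
import torch
from torch._export.constraints import constrain_as_size

def f(x):
    a = x.item()
    constrain_as_size(a, max=10)  # no explicit min: runtime-asserted as [0, 10]
    return torch.zeros(a)

ep = torch._export.export(f, (torch.tensor([5]),))
```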

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
2023-08-15 05:41:43 +00:00
gmagogsfm
8c9b2fe8f0 ExportedProgram.transform now updates graph_signature automatically (#107080)
Update graph_signature according to the graph after transformation.
Transformations can lead to node name changes, and graph_signature uses
node names to identify inputs and outputs. Therefore, after each
transformation, we need to update the graph_signature to match the new
node names.

WARNING: This implementation makes a few assumptions:
- The transformation doesn't change the number of inputs/outputs.
- Each input/output still has the same meaning:
  - For inputs, the inputs of the transformed graph map to the same
    lifted parameter/buffer or user input as the input at the same
    position in the graph before transformation.
  - Similarly for outputs, each output should correspond to the same
    mutated buffer or user output as the output value at the same
    position in the graph before transformation.

It is difficult to programmatically validate these assumptions, but they
should hold true most of the time, as graph inputs/outputs rarely need
to be changed.
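A usage sketch under the assumptions above: the pass keeps input/output arity and meaning, so transform() can remap graph_signature by position. Here `ep` stands for an existing ExportedProgram, and the relu-to-gelu rewrite is purely illustrative.
```python
import torch
from torch.fx.passes.infra.pass_base import PassResult

def relu_to_gelu(gm: torch.fx.GraphModule) -> PassResult:
    # Same number of inputs/outputs, same meaning: only node targets change.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.relu.default:
            node.target = torch.ops.aten.gelu.default
    gm.recompile()
    return PassResult(gm, True)

ep2 = ep.transform(relu_to_gelu)  # graph_signature is updated automatically
```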
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107080
Approved by: https://github.com/tugsbayasgalan
2023-08-14 19:52:41 +00:00
gmagogsfm
f26aa2dcd9 Keep fx node name consistent with aot_export (#107068)
torch.export() initially starts with the node names produced by aot_export. Without this change, any no-op transformation would break name consistency, thus breaking GraphSignature correctness.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107068
Approved by: https://github.com/tugsbayasgalan
2023-08-12 23:12:03 +00:00
PyTorch MergeBot
745d29b0cc Revert "[export] Refactor constrain_as_value and constrain_as_size (#106591)"
This reverts commit 18989890bf.

Reverted https://github.com/pytorch/pytorch/pull/106591 on behalf of https://github.com/izaitsevfb due to Breaks inductor test on trunk ([comment](https://github.com/pytorch/pytorch/pull/106591#issuecomment-1675069091))
2023-08-11 16:37:47 +00:00
Tugsbayasgalan Manlaibaatar
18989890bf [export] Refactor constrain_as_value and constrain_as_size (#106591)
Some notable changes:
1. `constrain_as_size` now allows the min value to be less than 2, since the compiler unconditionally assumes min >= 2 for compilation purposes anyway. Instead, we add an additional check to make sure the max value is always greater than 2.
2. Previously, we used to runtime-assert on the unbacked symint's value range, which would always be [2, max]. I modified this logic to assert on [0, max] unless the user explicitly specifies the min range.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106591
Approved by: https://github.com/gmagogsfm, https://github.com/ezyang
2023-08-11 05:29:22 +00:00
Zhengxu Chen
2dbadd1eae [export] Remove experimental runtime assertion configs from export API. (#105043)
Test Plan: CI

Differential Revision: D47390794

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105043
Approved by: https://github.com/larryliu0820
2023-07-26 16:21:29 +00:00
xuanqi
3707fbf63b [RFC]: Add test for graph partition after assertion ops functionalization. (#104287)
This PR:
* Address comment at https://github.com/pytorch/pytorch/pull/103887/files#r1244128266.
* Add a test for graph partition to make sure assertion ops functionalization won't break graph partitioning in unexpected ways.

**NOTE**:
In the context of export, it's entirely up to the user to do any kind of graph partition based on their specific use case. It's hard to anticipate the concrete downstream use case or to provide any specific functionality to facilitate handling assertion ops (functional / non-functional). So this PR limits itself to [`CapabilityBasedPartitioner`](2da6cae43c/torch/fx/passes/infra/partitioner.py (L34)) and makes sure it doesn't break graph partitioning unexpectedly (by adding some tests).
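For reference, a minimal sketch (not the PR's test itself) of how `CapabilityBasedPartitioner` is typically driven; the operator-support policy here is a hypothetical one that only fuses add/relu, producing `fused_*` submodules like the graphs shown below.
```python
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

class AddReluSupport(OperatorSupportBase):
    # Treat only aten add/relu calls as fusable.
    def is_node_supported(self, submodules, node):
        return node.op == "call_function" and node.target in (
            torch.ops.aten.add.Tensor,
            torch.ops.aten.relu.default,
        )

def partition(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    partitioner = CapabilityBasedPartitioner(gm, AddReluSupport())
    # propose_partitions groups supported nodes; fuse_partitions rewrites
    # each group into a `fused_*` submodule.
    partitions = partitioner.propose_partitions()
    return partitioner.fuse_partitions(partitions)
```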

For the test case used in PR, a few things to highlight:
* Without the assertion, the fused graph looks roughly like:
```
class fused(torch.nn.Module):
    def forward(self, a, b):
        fused_1 = self.fused_1(a, b);
        relu = fused_1.relu()
        fused_0 = self.fused_0(fused_1, relu)
        return (fused_0, fused_1)

    class fused_0(torch.nn.Module):
        def forward(self, add_2, relu):
            ... # Logic after relu
            return add_4

    class fused_1(torch.nn.Module):
        def forward(self, a, b):
            ... # Logic before relu, `add_1` is only exposed within this submodule.
            return add_2
```
* With the assertion, the fused graph looks roughly like:
```
class fused(torch.nn.Module):
    def forward(self, arg0_1: i64[s0], arg1_1: i64[s0]):
        dep_token0 = ...
        ...
        fused_1 = self.fused_1(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
        ...
        getitem: i64[s0] = fused_1[0] # `getitem` is actually `add_1`
        ...
        relu_default: i64[s0] = torch.ops.aten.relu.default(getitem_1)
        ...
        # For inline assertion. Note that `getitem`, which is an output of `fused_1`, is consumed by the assertion.
        select_int: i64[] = torch.ops.aten.select.int(getitem, 0, 0)
        eq_scalar: b8[] = torch.ops.aten.eq.Scalar(select_int, 5)
        dep_token2: f32[] = torch.ops.aten._functional_assert_async.msg(
            eq_scalar, 'assertion error', dep_token = dep_token1
        )
        ...
        getitem_1: i64[s0] = fused_1[1] # `getitem_1` is actually `add_2`
        fused_0: i64[s0] = self.fused_0(getitem_1, relu_default)
        ...

        return (fused_0, getitem_1, dep_token2)

    class fused_0(torch.nn.Module):
        def forward(self, add_tensor_2: i64[s0], relu_default: i64[s0]):
            ... # Logic after relu
            return add_tensor_4

    class fused_1(torch.nn.Module):
        def forward(self, arg0_1: i64[s0], arg1_1: i64[s0]):
            ... # Logic before relu
            # `add_tensor_1` (basically `add_1`) is returned to allow the downstream assertion op to consume it.
            return (add_tensor_1, add_tensor_2)
```

As shown above, the extra assertion (regardless of whether it's functionalized or not) **won't** cause extra submodule breakage when the asserted node is an intermediate node within a submodule: the intermediate node is returned as an extra output of the submodule so the downstream assertion node can consume it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104287
Approved by: https://github.com/tugsbayasgalan
2023-06-28 22:13:27 +00:00
xuanqi
bf34ecd0c8 [RFC]: Integrate assertions functionalization to export (after AOT export) (#103887)
This PR integrates the assertion functionalization logic into the current export logic.

**NOTE:**
I finally decided to do the assertion functionalization after AOT export instead of before, for the following reasons:
* The benefit of doing it after AOT export is that the graph is already functionalized, so things like method calls have already been transformed into function calls. If we did it before AOT export, the graph would still be at the torch level, and extra logic like bab21d20eb/torch/_export/pass_base.py (L201-L204C17) would need to be implemented.
* The graph signature is currently already somewhat incorrect after adding runtime assertions (this doesn't seem to break logic, since we already depend on positions instead of FQNs of outputs). This PR also fixes this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103887
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
2023-06-27 18:14:29 +00:00
xuanqi
344bab2669 [RFC]: Functionalize assertions (#103757)
The idea here is to do a graph mutation to:
* Create an initial dependency token at the beginning of the program.
* Replace the non-functional version of assertion statements with the functional version.
* The functional version of an assertion statement will:
  * Accept a dependency token from the output of the previous functional assertion statement (or the initial dependency token if there isn't one).
  * Generate a dependency token as the output of the assertion statement.
  * Augment the graph output to include the dependency token generated by the last assertion statement.

The goal here is to:
* Form an explicit dependency chain and avoid potential reordering during other compilation passes.
* Make the assertions part of the overall execution graph that affects the final output (otherwise they could potentially be DCEed).
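A minimal runnable sketch of the token-threading idea (illustrative only, not the pass's actual output; `functional_assert` is a hypothetical eager stand-in for `aten._functional_assert_async.msg`): each assert consumes the previous token and produces a new one, so data dependencies alone pin the execution order.
```python
import torch

def functional_assert(cond: torch.Tensor, msg: str, dep_token: torch.Tensor) -> torch.Tensor:
    # Checks eagerly and returns a fresh token that downstream asserts consume.
    torch._assert(bool(cond), msg)
    return dep_token.clone()

def program(x: torch.Tensor):
    dep_token0 = torch.empty(0)  # initial dependency token
    a = x + 1
    dep_token1 = functional_assert(a.min() >= 0, "a must be non-negative", dep_token0)
    b = a.relu()
    dep_token2 = functional_assert(b.max() <= 100, "b out of range", dep_token1)
    # The last token is appended to the outputs so the chain is not DCEed.
    return b, dep_token2
```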

**NOTE:**
* Currently this only covers `constrain_range`; support for other assertions is WIP. Sending out this PR to collect feedback first.
* This PR focuses only on the implementation itself; integration with the current export flow will come in a future PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103757
Approved by: https://github.com/avikchaudhuri
2023-06-24 00:23:35 +00:00
xuanqi
b27c3558a4 [RFC]: Create aten native op for constrain_range (#103346)
At a high level, the current implementation of the constraint functions (`constrain_as_*`) will raise an exception for the following code snippet:
```
def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```

The reason is that the current constraint logic:
1) Is purely Python, so it won't survive AOT export (the node is gone after AOT export, since AOT export only keeps aten-level ops).
2) Relies on a side effect to add range constraints to the traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))).
3) If runtime assertions are turned on (the default), [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) will try to append assertion nodes based on the range constraints extracted from the symbols' shape env during another interpretation round.
4) However, because of 1), the range-constraint logic won't run for symbols generated during the AOT export round, so no range-constraint information is available for the assertion round, which causes the issue.
5) As a result, it fails at `torch.empty((a, 4))` (there is no constraint that `a` must be positive).

The fix here is to implement the range-constraint logic as a native aten op (with a no-op CPU implementation) so that it can survive AOT export.

**NOTE:**
The [logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to handle the case when a non-`SymInt` is passed in, and it is reused in the new `_constrain_range`. The reason is that when a non-`SymInt` is provided:
* If it directly called `sym_constrain_range`, the C++ version would be called, which is a no-op.
* So in this case it calls `constrain_range_int` instead, to be able to catch issues like the user providing an input whose tensor shape could be out of range during export, as in the following variation of the code example above:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,)) # immediately raise error
```
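A rough sketch of the dispatch split described in the NOTE (assumed from the description, not quoted from the source; the inline range check stands in for `constrain_range_int`, and the `sym_constrain_range` keyword schema is an assumption): concrete ints are validated eagerly, while SymInts hit the aten op so the constraint survives AOT export.
```python
import torch
from torch import SymInt

def _constrain_range_sketch(a, *, min: int, max: int) -> None:
    if not isinstance(a, SymInt):
        # Eager path: fail immediately if the concrete value is out of range.
        if not (min <= a <= max):
            raise ValueError(f"{a} is outside of the range [{min}, {max}]")
        return
    # Traced path: recorded in the graph; the C++ kernel is a no-op at runtime.
    torch.ops.aten.sym_constrain_range(a, min=min, max=max)
```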

Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
2023-06-16 14:55:40 +00:00
Tugsbayasgalan Manlaibaatar
4bb2b65ea4 Turn on add_runtime_assertion by default (#102671)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102671
Approved by: https://github.com/angelayi, https://github.com/avikchaudhuri
2023-06-05 16:27:44 +00:00
Angela Yi
7a569f86a0 [export] Cleanup constraints (#102666)
Redo of https://github.com/pytorch/pytorch/pull/102432 because idk how to push to that other branch...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102666
Approved by: https://github.com/zhxchen17
2023-06-01 04:22:31 +00:00
Tugsbayasgalan Manlaibaatar
d9f75dded1 [export] Add aot_export 1/N (#101490)
This PR adds aot_export_module as the lowering path from the torch-level graph to the aten graph. Some known limitations that need to be addressed in follow-up PRs:
1. Store param/buffer data in ExportedProgram
2. Fully support torch.cond with params/buffers
3. Making sure no duplicated ExportMetaData entry
4. This API will break Executorch if used on PyE; we will figure out a plan internally.
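A minimal sketch of the lowering step described above (the exact `aot_export_module` signature and return values are assumed, not quoted from the PR):
```python
import torch
from torch._functorch.aot_autograd import aot_export_module

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin() + 1

# Lowers the torch-level module to a functionalized aten-level graph
# plus a signature describing lifted params/buffers and user I/O.
gm, graph_signature = aot_export_module(M(), (torch.randn(3),), trace_joint=False)
print(gm.graph)
```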

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101490
Approved by: https://github.com/avikchaudhuri
2023-05-31 20:56:21 +00:00
Angela Yi
c4028de462 [export] ExportedProgram (#102259)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102259
Approved by: https://github.com/ydwu4, https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan, https://github.com/zhxchen17
2023-05-26 23:36:38 +00:00
Avik Chaudhuri
8751002215 equality assertions (#102256)
Previously we had runtime asserts for range constraints. This diff adds runtime asserts for equality constraints.

This requires a bit of refactoring that is worth calling out.
1. [Minor] Some of the data structures produced by export and consumed by the runtime assertion pass need to be broadened. This is a WIP. There are some associated code improvements that are included in this diff, but by and large the structures are similar to what exists now. Meanwhile @angelayi and I are chatting about how to make it qualitatively better: briefly, we want to index everything by symbols, which are 1-1 with (name, dim) pairs.
2. [Major] The order in which runtime asserts are emitted is changed. Previously we used to do the work in `placeholder`; now this diff adds a hook for "post-processing" after all placeholders have been processed. This is needed because equality constraints can mention different placeholders. This change also opens the way to optimizing codegen: e.g., each (name, dim) pair should correspond to a single intermediate variable that is reused across runtime asserts. This is future work.
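A sketch of the kind of equality constraint the new runtime asserts cover (the `torch._export.dynamic_dim` constraint API is assumed to be the era-appropriate way to express this): dim 0 of the two inputs must match, and the emitted assert checks that at runtime.
```python
import torch
from torch._export import export, dynamic_dim

def f(x, y):
    return x + y

x, y = torch.randn(4, 3), torch.randn(4, 3)
# Equality constraint across two different placeholders.
ep = export(f, (x, y), constraints=[dynamic_dim(x, 0) == dynamic_dim(y, 0)])
```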

Differential Revision: [D46177642](https://our.internmc.facebook.com/intern/diff/D46177642/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102256
Approved by: https://github.com/tugsbayasgalan, https://github.com/angelayi
2023-05-26 14:57:31 +00:00
Tugsbayasgalan Manlaibaatar
47f43ed84a Actually functionalize torch.export (#101433)
I thought I had enabled this, but apparently not. This PR makes the export fully functional for real this time :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101433
Approved by: https://github.com/angelayi
2023-05-17 05:09:24 +00:00
PyTorch MergeBot
eac5f2a8e4 Revert "Actually functionalize torch.export (#101433)"
This reverts commit eec752ed05.

Reverted https://github.com/pytorch/pytorch/pull/101433 on behalf of https://github.com/PaliC due to causing failures on functorch macOS tests ([comment](https://github.com/pytorch/pytorch/pull/101433#issuecomment-1550111671))
2023-05-16 17:51:45 +00:00
Tugsbayasgalan Manlaibaatar
eec752ed05 Actually functionalize torch.export (#101433)
I thought I had enabled this, but apparently not. This PR makes the export fully functional for real this time :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101433
Approved by: https://github.com/angelayi
2023-05-16 16:22:13 +00:00
Tugsbayasgalan Manlaibaatar
194d360329 Add more canonical way of adding runtime pass (#100956)
* #100955
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100956
Approved by: https://github.com/ydwu4, https://github.com/guangy10
2023-05-16 03:23:04 +00:00
Tugsbayasgalan Manlaibaatar
9ffad5b62b Remove input tracker from runtime assertion pass (#100955)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100955
Approved by: https://github.com/ydwu4
2023-05-15 21:26:47 +00:00
Tugsbayasgalan Manlaibaatar
f542b31c9d [export] More robust view->view_copy pass (#100908)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100908
Approved by: https://github.com/ydwu4
2023-05-10 14:25:17 +00:00
ydwu4
26cd958718 Support runtime assertion for inline constraints (#100763)
This PR does the following:
1. Previously, inline constraints were not properly set for data-dependent ops with tensor outputs, such as `a.nonzero()`, because the return value is not a SymInt. This PR just uses all the unbacked symbols, i.e. those starting with "i"/"f" in the create_unbacked_sym* functions. Note that these symbols are guaranteed to be a superset of the inline user constraints.

2. Adds support for inline assertions by checking these constraints.

Currently, it only deals with data-dependent ops whose outputs are tensors, SymInts, SymFloats, or SymBools, and ignores the rest. That's good enough for now, as we only have a limited number of data-dependent ops (`.item` and `.nonzero` are explicitly tested).

An example of a graph with added assertions is shown below:

```
class ExportGraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        nonzero_default: i64[i0, 1] = torch.ops.aten.nonzero.default(arg0);  arg0 = None
        return pytree.tree_unflatten([nonzero_default], self._out_spec)

class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        sym_size: Sym(s0) = torch.ops.aten.sym_size(arg0, 0)
        nonzero_default: i64[i1, 1] = torch.ops.aten.nonzero.default(arg0);  arg0 = None
        sym_size_1: Sym(i1) = torch.ops.aten.sym_size(nonzero_default, 0)
        ge: Sym(i1 >= 3) = sym_size_1 >= 3
        scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(ge);  ge = None
        _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].');  scalar_tensor_default = None
        le: Sym(i1 <= 5) = sym_size_1 <= 5;  sym_size_1 = None
        scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(le);  le = None
        _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].');  scalar_tensor_default_1 = None
        return pytree.tree_unflatten([nonzero_default], self._out_spec)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100763
Approved by: https://github.com/tugsbayasgalan
2023-05-09 04:19:57 +00:00
Angela Yi
2d2f716ddc [export] Fix cond for pass_base (#100836)
I ported over the code for the inline interpreter incorrectly in the pass base 😅

Originally, the function `make_inline_interpreter` was supposed to take an fx.Interpreter type, but I accidentally passed in an fx.Interpreter object. While modifying this diff (and from Tugsuu's comments), I also realized that we don't really need this InlineInterpreter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100836
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
2023-05-08 21:51:03 +00:00
Tugsbayasgalan Manlaibaatar
9b3552eb2c Add runtime assertions for input shape constraints (#100247)
This PR adds runtime assertions as an extra pass on the exported graph. Some high-level notes:
1. We specialize all dimensions that were not included in the user's input constraints.
2. We haven't added relational constraints as runtime assertions (e.g. x[1] == x[0]); we will do that in a follow-up diff.
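An illustrative sketch of point 1 (the module is hypothetical, and the call convention for the exported object in this era is assumed): dims without user constraints are specialized, so passing a different size trips the inserted assert.
```python
import torch
from torch._export import export

def f(x):
    return x * 2

ep = export(f, (torch.randn(4, 3),))  # no constraints: 4 and 3 are specialized
ep(torch.randn(4, 3))    # ok
# ep(torch.randn(5, 3))  # would trip the inserted runtime assert (dim 0 != 4)
```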

Differential Revision: [D45408971](https://our.internmc.facebook.com/intern/diff/D45408971)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100247
Approved by: https://github.com/guangy10, https://github.com/avikchaudhuri
2023-05-04 13:26:58 +00:00
Angela Yi
7bece142a9 [export] Port over const prop pass (#100102)
Stacked on top of https://github.com/pytorch/pytorch/pull/100000
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100102
Approved by: https://github.com/gmagogsfm
2023-04-27 17:06:47 +00:00
Angela Yi
9bbd3d6489 [export] ExportPassBase + view_copy pass (#100000)
* Added ExportPassBase, an interpreter-based helper class for writing passes.
* It can also help maintain the dialect based on the operator namespace by having users override the `get_valid_dialects` function (returning an empty list implies the pass works for any dialect).
* Added a `ReplaceBrokenOpsWithFunctionalOpsPass` to replace all ops that functionalization has not yet converted with their functional counterparts.
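A minimal sketch of a pass written against ExportPassBase (the `call_operator` extension point is assumed from torch/_export/pass_base.py of this era; the view -> view_copy rewrite mirrors the companion pass rather than reproducing it):
```python
import torch
from torch._export.pass_base import ExportPassBase

class ViewToViewCopyPass(ExportPassBase):
    def call_operator(self, op, args, kwargs, meta):
        # Swap the non-functional view for its functional view_copy variant.
        if op == torch.ops.aten.view.default:
            op = torch.ops.aten.view_copy.default
        return super().call_operator(op, args, kwargs, meta)

# Usage: ViewToViewCopyPass()(graph_module) returns a PassResult with the
# rewritten GraphModule.
```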
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100000
Approved by: https://github.com/gmagogsfm
2023-04-26 21:01:25 +00:00