Commit Graph

5 Commits

Author SHA1 Message Date
xuanqi
b27c3558a4 [RFC]: Create aten native op for constrain_range (#103346)
At high current implementation of constrains functions (constrain_as_**) will raise exception for the following code snippets:
```
def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```

The reason is because current constrain logic is:
1) Purely python so it won't survive AOT export (the full node is gone after AOT export since AOT export only maintains aten level op).
2) Utilize side effect to add range constraints for traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))).
3) If runtime assertion is turned on (by default). [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) will try to append assertion node based on range constrains extracted from shape env of symbol during another interpretation round.
4). However, since 1), in the round of AOT export, range constraints logic won't run for symbols generated during this round. And later there is no range constrains information available for assertion round and caused issue.
5) As a result of above, it will failure at `torch.empty((a, 4))` (there is no constrains for `a` that it must be positive).

The fix here is just to implement range constrain logic as a native aten op (CPU implementation as no-op) to make it be able to survive AOT export.

**NOTE:**
[Logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to capture case when non `SymInt` is passed in and reused in the new `_constrain_range`. The reason is when non `SymInt` is provided:
* If it directly calls `sym_constrain_range`, the C++ version will be called which will be no-op.
* So in this case it calls `constrain_range_int` instead to be able to capture issue like user provides a input whose tensor's shape could be out of range during exporting, like the following for above code example:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,)) # immediately raise error
```

Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
2023-06-16 14:55:40 +00:00
Tugsbayasgalan (Tugsuu) Manlaibaatar
02f059c2b7 Add private _export API (#99992)
Differential Revision: D45279206

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99992
Approved by: https://github.com/angelayi, https://github.com/gmagogsfm
2023-04-27 16:24:14 +00:00
Angela Yi
1d077f28ed [export] Constraints API (#98433)
Wrapper for users to insert constraints into model code.

The constraints will not be maintained in the graph after tracing through make_fx so retracing with dynamo/make_fx will not work. This will be supported after torch._assert supported is implemented. Then we can convert the constrain_range calls to torch._asserts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98433
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
2023-04-13 21:20:10 +00:00
PyTorch MergeBot
ab761605ae Revert "[export] Constraints API (#98433)"
This reverts commit 1510eb4072.

Reverted https://github.com/pytorch/pytorch/pull/98433 on behalf of https://github.com/izaitsevfb due to Breaks internal tests, asked by author to revert
2023-04-12 23:37:19 +00:00
Angela Yi
1510eb4072 [export] Constraints API (#98433)
Wrapper for users to insert constraints into model code.

The constraints will not be maintained in the graph after tracing through make_fx so retracing with dynamo/make_fx will not work. This will be supported after torch._assert supported is implemented. Then we can convert the constrain_range calls to torch._asserts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98433
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
2023-04-12 01:32:44 +00:00