Summary: The internal model and ResNet use the "re-export" flow now. Also did some refactoring to make the code a little cleaner.
Some changes for OSS:
1. Correctly use the "cached" fake tensors so that static symbols still resolve as static.
2. Change the logic in PassBase to allocate static shapes for parameters.
3. Add an "is_torch_exported" tag to every node so that it survives various graph transformations (see the sketch after this list).
4. Added an experimental wrapper API for the quantization team to get a pre_dispatch=True graph. Note that it doesn't actually do that yet, but we plan to switch soon.
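A rough sketch of how such a per-node tag can be attached via `node.meta` (the key name comes from the summary above; how the pass actually applies and checks it is an assumption):
```python
import torch
from torch.fx import symbolic_trace

def tag_as_exported(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Store the tag in node.meta; passes that copy node metadata carry it along.
    for node in gm.graph.nodes:
        node.meta["is_torch_exported"] = True
    return gm

gm = symbolic_trace(lambda x: x + 1)
tag_as_exported(gm)
assert all(n.meta.get("is_torch_exported") for n in gm.graph.nodes)
```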
Test Plan: CI
Differential Revision: D47890878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106676
Approved by: https://github.com/jerryzh168
https://github.com/pytorch/pytorch/issues/105555
The existing flow first exports and then calls torch._inductor.aot_compile. However, export calls aot_autograd with the core aten decomposition table, and then torch._inductor.aot_compile calls aot_autograd again with the inductor decomposition table. The second call to aot_autograd reportedly causes problems and seems excessive, so instead we will create a new function, torch._export.aot_compile, which exports using the inductor decomposition table and passes the result to inductor's compile_fx_aot; because the program has already been exported, this avoids re-calling aot_autograd.
```
def aot_compile(
    f: Callable,
    args: Tuple[Any],
    kwargs: Optional[Dict[str, Any]] = None,
    constraints: Optional[List[Constraint]] = None,
) -> Tuple[str, ExportedProgram]:
```
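A usage sketch based on the signature above; the import path, the ability to pass an `nn.Module` directly as the callable, and the example inputs are assumptions:
```python
import torch
from torch._export import aot_compile  # assumed import path

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

# Export with the inductor decomposition table and compile ahead of time;
# since the program is already exported, aot_autograd is not called a second time.
so_path, exported_program = aot_compile(M(), (torch.randn(8, 8),))
print(so_path)  # path to the compiled artifact
```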
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105977
Approved by: https://github.com/desertfire, https://github.com/zhxchen17, https://github.com/eellison
Summary: When we do a deep copy of the ExportedProgram, the custom deepcopy override fails to copy over the graph metadata (graph.meta). This could be fixed, but overall I don't see a need for a custom deepcopy in ExportedProgram, so this tries to get rid of it.
Test Plan: CI
Differential Revision: D48043723
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106578
Approved by: https://github.com/JacobSzwejbka
Summary:
Adding support for edge dialect ops in `exir/serde`. This diff does the following:
- Moves the global `serialize_operator/deserialize_operator` implementations in `export/serde/serialize.py` into `GraphModuleSerializer` and `GraphModuleDeserializer`
- Adds implementations of `serialize_operator/deserialize_operator` inside `GraphModuleSerializer` and `GraphModuleDeserializer` in `exir/serde/serialize.py`
Test Plan: CI + Enabled edge dialect ops in `executorch/exir/tests/test_serde.py`
Differential Revision: D47938280
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106371
Approved by: https://github.com/angelayi
@SherlockNoMad mentioned that it's not BC-safe to directly access these attributes, so I moved them to `@property` fields and added a `@compatibility` decorator. For now I just set it to True for graph_module/graph.
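A sketch of the pattern described above; the class here is an illustrative stand-in, not the actual ExportedProgram code, and the decorator argument is an assumption:
```python
import torch
from torch.fx._compatibility import compatibility

class ProgramLike:
    """Illustrative stand-in for an exported-program-style object."""

    def __init__(self, graph_module: torch.fx.GraphModule):
        self._graph_module = graph_module

    @property
    @compatibility(is_backward_compatible=True)
    def graph_module(self) -> torch.fx.GraphModule:
        # Read-only property instead of a directly accessed attribute.
        return self._graph_module

    @property
    @compatibility(is_backward_compatible=True)
    def graph(self) -> torch.fx.Graph:
        return self._graph_module.graph
```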
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106170
Approved by: https://github.com/SherlockNoMad
* Enables PIE807 and PIE810. PIE807 flags reimplementing the `list` builtin with a lambda, and PIE810 flags `startswith`/`endswith` call chains that should be fused into a single tuple call (I applied the autofixes for this before we had ruff enabled).
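Illustrative before/after examples of the two rules (made up, not taken from the codebase):
```python
name = "aten.add.Tensor"

# PIE807: do not reimplement the list builtin with a lambda.
make_row = lambda: []  # flagged by PIE807
make_row = list        # preferred

# PIE810: fuse multiple startswith/endswith calls into a single tuple call.
flagged = name.startswith("aten") or name.startswith("prims")  # flagged by PIE810
preferred = name.startswith(("aten", "prims"))                 # preferred
```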
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106218
Approved by: https://github.com/albanD
In this PR, we allow users to register customized flatten/unflatten/serialization/deserialization for a dataclass. We provide a default implementation for flatten/unflatten, and could implement a decorator based on it when needed.
## Motivation:
HuggingFace and many internal models return dataclass outputs, and torch.export wants to maintain the invariant that the export result (i.e. the exported_program) has the same calling convention and result as the original callable.
This is not supported in export yet: we cannot recover the original dataclass from the flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need a place to store the metadata of the dataclass so that we can re-construct it. To avoid adding hacky code in export and to allow principled extensibility, we think extending pytree may be a good option.
## Implementation:
@zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and [jax-2371](https://github.com/google/jax/issues/2371#issuecomment-805361566), which suggest that it's not a good idea to make dataclass a default pytree node, but that it could be good to provide a default implementation for dataclasses. Since this currently seems to be an export-only feature, we added this extension point in export.
We also add a "return_none_fields" flag to control whether None fields are returned after flattening; it is expected to be False in produce_matching of dynamo.export.
Also added some tests.
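To illustrate the flatten/unflatten idea this enables, here is a sketch that uses the generic private pytree helper; the PR's actual registration point is export-specific, so the helper, the dataclass, and the field-name context below are all assumptions for illustration only:
```python
from dataclasses import dataclass, fields

import torch
import torch.utils._pytree as pytree

@dataclass
class ModelOutput:  # stand-in for a HuggingFace-style dataclass output
    logits: torch.Tensor
    loss: torch.Tensor

# Default-style flatten/unflatten: field values become leaves and the field
# names are kept as context so the dataclass can be reconstructed.
def _flatten(out: ModelOutput):
    names = [f.name for f in fields(out)]
    return [getattr(out, n) for n in names], names

def _unflatten(values, names):
    return ModelOutput(**dict(zip(names, values)))

pytree._register_pytree_node(ModelOutput, _flatten, _unflatten)

flat, spec = pytree.tree_flatten(ModelOutput(torch.ones(2), torch.zeros(1)))
restored = pytree.tree_unflatten(flat, spec)  # a ModelOutput again
```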
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106160
Approved by: https://github.com/zhxchen17
Summary: Sometimes the graph that is being serialized contains nodes with side effects + no users (ex. out variants of operators), so we don't want to eliminate those when deserializing.
Test Plan: CI
Differential Revision: D47735009
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105875
Approved by: https://github.com/ydwu4
Summary: ExirExportedProgram would like to have this feature. Today it implements it itself since it inherits from ExportedProgram, but since we are moving it to composition, I think it would be cleaner to upstream the behavior into the root object anyway.
Test Plan: CI, but TODO: where are the tests for this file?
Differential Revision: D47645843
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105852
Approved by: https://github.com/tugsbayasgalan
Solving #105242.
During export, the exported function's signature changes multiple times. Suppose we'd like to export f as shown in the following example:
```python
def f(arg1, arg2, kw1, kw2):
    pass

args = (arg1, arg2)
kwargs = {"kw2": arg3, "kw1": arg4}
torch.export(f, args, kwargs)
```
The signature changes multiple times during the export process, in the following order:
1. **gm_torch_level = dynamo.export(f, *args, \*\*kwargs)**. In this step, we turn all kinds of parameters, such as **positional_only**, **var_positional**, **kw_only**, and **var_kwargs**, into **positional_or_kw**. It also preserves the positional and keyword argument names of the original function (i.e. f in this example) [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/export.py#L546C13-L546C27). The order of kwargs will be the **key order** of kwargs (since Python 3.6, this is the insertion order of keys) instead of the order in the original function signature, and this order is baked into the _orig_args variable of gm_torch_level's pytree info. So we'll have:
```python
def gm_torch_level(arg1, arg2, kw2, kw1): ...
```
Such a difference is acceptable, as it's transparent to users of export.
2. **gm_aot_export = aot_export_module(gm_torch_level, pos_or_kw_args)**. In this step, we need to turn kwargs into positional args in the order gm_torch_level expects, which is stored in _orig_args. The returned gm_aot_export has the graph signature of flat_args, where `flat_args, in_spec = pytree.tree_flatten(pos_or_kw_args)`:
```python
flat_args, _ = pytree.tree_flatten(pos_or_kw_args)

def gm_aot_export(*flat_args): ...
```
3. **exported_program(*args, \*\*kwargs)**. The exported artifact is exported_program, which is a wrapper over gm_aot_export and has the same calling convention as the original function f. To do this, we need to (1) specialize the order of kwargs into pos_or_kw_args and (2) flatten pos_or_kw_args into what gm_aot_export expects. We can combine the two steps into one with:
```python
_, in_spec = pytree.tree_flatten((args, kwargs))
# Then during exported_program.__call__(*args, **kwargs)
flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
```
Here kwargs is treated as a normal pytree whose key order is preserved in in_spec.
Implementation-wise, we treat _orig_args in the dynamo-exported graph module as the single source of truth, and kwargs are ordered following it.
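A small self-contained illustration of the flattening convention in step 3 (the values are made up):
```python
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree

args = (1, 2)
kwargs = {"kw2": 3, "kw1": 4}

# in_spec records the key order of kwargs as given at export time.
flat, in_spec = pytree.tree_flatten((args, kwargs))
print(flat)  # [1, 2, 3, 4] -- kw2 before kw1, following insertion order

# At call time, the same spec flattens user inputs into the order the
# underlying graph module expects, regardless of how kwargs are spelled.
flat_again = fx_pytree.tree_flatten_spec(((1, 2), {"kw1": 4, "kw2": 3}), in_spec)
print(flat_again)  # [1, 2, 3, 4]
```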
Test plan:
See added tests in test_export.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105337
Approved by: https://github.com/angelayi, https://github.com/tugsbayasgalan
Summary: We want to do this little by little. For now, I tried it only on DissectedPartsModel, which needs to use the aot_export version.
Test Plan: CI
Reviewed By: zhxchen17
Differential Revision: D46785735
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104897
Approved by: https://github.com/JacobSzwejbka
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)
These were reverted due to a conflict with the internal source repo.
Mostly fixes for PEP 484 violations (i.e. when a default arg is set to None but the type is not annotated as Optional).
Plus a few real fixes:
- Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
- Add missing return statement to `torch._export.deserialize_graph`
- Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
- Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in followup PR):
- Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to `.ci/docker/install_conda.sh` to squash the older libstdc++ from the conda environment in favor of the one from the OS
- Update bazel CUDA builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
Summary: Submodules may have None call-spec values, which is OK. Updating types + serializer to handle this.
Test Plan: CI
Reviewed By: ydwu4, zhxchen17
Differential Revision: D47353101
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105179
Approved by: https://github.com/zhxchen17
This reduces the total number of modules imported by default from 1419 to 1322, according to
```
time python -c "import sys;before=len(sys.modules);import torch;after=len(sys.modules);print(f'torch-{torch.__version__} imported {after-before} modules')"
```
and slightly reduces import time, while having no effect on UX (i.e. `torch.onnx.` submodule is kept intact)
Suppress lint errors that appeared after mypy accidentally started listing more files; for more details see: https://github.com/pytorch/pytorch/issues/104940
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104843
Approved by: https://github.com/jansel, https://github.com/albanD
Summary:
Some serialized nn_module_stacks contain nested commas, something like:
`(getitem(L['module'],0),torch.nn.modules.linear.Linear)`
Fixing the parsing so that we can deserialize strings of the format `(local identifier, module type)`.
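One way to handle the nested commas is to split only at the top nesting level; this is a sketch, not necessarily the exact approach taken in the PR:
```python
def split_nn_module_stack_entry(entry: str):
    # entry looks like "(getitem(L['module'],0),torch.nn.modules.linear.Linear)"
    # and must be split into (local identifier, module type) even though the
    # local identifier itself may contain commas inside parentheses/brackets.
    inner = entry[1:-1]  # strip the outer parentheses
    depth = 0
    for i, ch in enumerate(inner):
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        elif ch == "," and depth == 0:
            return inner[:i], inner[i + 1:]
    raise ValueError(f"malformed nn_module_stack entry: {entry}")

print(split_nn_module_stack_entry("(getitem(L['module'],0),torch.nn.modules.linear.Linear)"))
# ("getitem(L['module'],0)", 'torch.nn.modules.linear.Linear')
```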
Test Plan: CI
Differential Revision: D47252881
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104721
Approved by: https://github.com/zhxchen17
This PR:
* Address comment at https://github.com/pytorch/pytorch/pull/103887/files#r1244128266.
* Add a test for graph partitioning to make sure assertion op functionalization won't break graph partitioning in an unexpected way.
**NOTE**:
In the context of export, it's entirely up to the user to do any kind of graph partitioning based on their specific use case. It's hard to anticipate the concrete downstream use case or to provide any specific functionality to facilitate handling assertion ops (functional / non-functional). So this PR limits itself to [`CapabilityBasedPartitioner`](2da6cae43c/torch/fx/passes/infra/partitioner.py (L34)) and makes sure it doesn't break graph partitioning unexpectedly (by adding some tests).
For the test case used in this PR, a few things to highlight:
* Without the assertion, the fused graph is roughly like:
```
class fused(torch.nn.Module):
    def forward(self, a, b):
        fused_1 = self.fused_1(a, b)
        relu = fused_1.relu()
        fused_0 = self.fused_0(fused_1, relu)
        return (fused_0, fused_1)

class fused_0(torch.nn.Module):
    def forward(self, add_2, relu):
        ...  # Logic after relu
        return add_4

class fused_1(torch.nn.Module):
    def forward(self, a, b):
        ...  # Logic before relu, `add_1` is only exposed within this submodule.
        return add_2
```
* With the assertion, the fused graph is roughly like:
```
class fused(torch.nn.Module):
    def forward(self, arg0_1: i64[s0], arg1_1: i64[s0]):
        dep_token0 = ...
        ...
        fused_1 = self.fused_1(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
        ...
        getitem: i64[s0] = fused_1[0]  # `getitem` is actually `add_1`
        ...
        relu_default: i64[s0] = torch.ops.aten.relu.default(getitem_1)
        ...
        # For the inline assertion. Note that `getitem`, which is an output of `fused_1`, is consumed by it.
        select_int: i64[] = torch.ops.aten.select.int(getitem, 0, 0)
        eq_scalar: b8[] = torch.ops.aten.eq.Scalar(select_int, 5)
        dep_token2: f32[] = torch.ops.aten._functional_assert_async.msg(
            eq_scalar, 'assertion error', dep_token = dep_token1
        )
        ...
        getitem_1: i64[s0] = fused_1[1]  # `getitem_1` is actually `add_2`
        fused_0: i64[s0] = self.fused_0(getitem_1, relu_default)
        ...
        return (fused_0, getitem_1, dep_token2)

class fused_0(torch.nn.Module):
    def forward(self, add_tensor_2: i64[s0], relu_default: i64[s0]):
        ...  # Logic after relu
        return add_tensor_4

class fused_1(torch.nn.Module):
    def forward(self, arg0_1: i64[s0], arg1_1: i64[s0]):
        ...  # Logic before relu
        # `add_tensor_1` (basically `add_1`) is returned so the downstream assertion op can consume it.
        return (add_tensor_1, add_tensor_2)
```
As shown above, the extra assertion (regardless of whether it's functionalized or not) **won't** cause extra submodule breakage if the asserted node is an intermediate node within a submodule: the intermediate node is simply returned as an extra output of the submodule so the downstream assertion node can consume it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104287
Approved by: https://github.com/tugsbayasgalan
This PR integrates the assertion functionalization logic into the current export logic.
**NOTE:**
I finally decided to do the assertion functionalization after AOT export instead of before for the following reasons:
* The benefit of AOT export is that the graph is already functionalized, so things like method calls are already transformed into function calls. However, if we did it before AOT export, the graph would still be at the torch level, and extra logic like bab21d20eb/torch/_export/pass_base.py (L201-L204C17) would need to be implemented.
* The graph signature is already somewhat incorrect after adding runtime assertions today (this doesn't seem to break anything, since we already depend on positions instead of FQNs of outputs). This PR also fixes that.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103887
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
Summary:
ETRecord can't use this yet because the other programs need to be migrated to using ExportedProgram (D46729844)
Note: higher order ops like call_delegate/cond are also not supported yet
Test Plan: `buck2 run @//mode/dev-nosan //executorch/exir/tests:serde`
Differential Revision: D46802454
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103763
Approved by: https://github.com/tarun292
The idea here is to do a graph mutation to:
* Create an initial dependency token at the beginning of the program.
* Replace the non-functional version of assertion statements with the functional version.
* The functional version of an assertion statement will:
  * Accept a dependency token from the output of the previous functional assertion statement (or the initial dependency token if there isn't one).
  * Generate a dependency token as the output of the assertion statement.
* Augment the output to include the dependency token generated by the last assertion statement.
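A rough before/after sketch of the rewrite, in the same style as the graph dumps earlier in this log (names are illustrative):
```
# Non-functional assertion: has no users, so it can be reordered or DCEed.
_assert_async = torch.ops.aten._assert_async.msg(eq_scalar, 'assertion error')
return (out,)

# Functionalized form: each assertion consumes the previous dependency token
# and produces a new one, and the last token is added to the outputs.
dep_token0 = ...  # initial dependency token created at the start of the program
dep_token1 = torch.ops.aten._functional_assert_async.msg(
    eq_scalar, 'assertion error', dep_token = dep_token0
)
dep_token2 = torch.ops.aten._functional_assert_async.msg(
    le_scalar, 'assertion error', dep_token = dep_token1
)
return (out, dep_token2)
```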
The goal here is to:
* Form an explicit dependency chain and avoid potential reordering during other compilation passes.
* Make the assertions a part of the overall execution graph that affects the final output (otherwise they could potentially be DCEed).
**NOTE:**
* Currently this only covers `constrain_range`; supporting other assertions is WIP. Sending out this PR to collect feedback first.
* This PR focuses only on the implementation itself; it will be integrated with the current export flow in a future PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103757
Approved by: https://github.com/avikchaudhuri
Summary:
We remove the ExportGraphModuleMixin. There are several implications of this change:
1. The graph_module of ExportedProgram, EdgeDialectProgram and ExecutorchProgram won't have the same signature as the original user function. Instead, we should directly call the *Program, which has the same calling convention.
2. All passes need to go through prog.transform(*passes). We need to make all passes return a PassResult (see the sketch after this list).
3. We also need to make sure the graph_module.meta is preserved after transform.
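A minimal sketch of what point 2 implies for pass authors, assuming `transform` accepts callables that take the graph module and return a `PassResult`:
```python
import torch
from torch.fx.passes.infra.pass_base import PassResult

def replace_add_with_mul(gm: torch.fx.GraphModule) -> PassResult:
    # Passes no longer mutate a module that mimics the user function; they
    # receive the program's graph_module and report whether they changed it.
    modified = False
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.aten.add.Tensor:
            node.target = torch.ops.aten.mul.Tensor
            modified = True
    gm.recompile()
    return PassResult(gm, modified)

# `prog` stands for an existing ExportedProgram / EdgeDialectProgram / ExecutorchProgram:
# new_prog = prog.transform(replace_add_with_mul)
```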
Test Plan: Test with CI.
Differential Revision: D46729844
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103786
Approved by: https://github.com/avikchaudhuri
At a high level, the current implementation of the constraint functions (`constrain_as_*`) will raise an exception for the following code snippet:
```
def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```
The reason is that the current constraint logic:
1) Is purely Python, so it won't survive AOT export (the node is gone after AOT export, since AOT export only keeps aten-level ops).
2) Uses a side effect to add range constraints to the traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))).
3) If runtime assertions are turned on (the default), [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) will try to append assertion nodes based on range constraints extracted from the symbol's shape env during another interpretation round.
4) However, because of 1), the range constraint logic won't run for symbols generated during the AOT export round, so later there is no range constraint information available for the assertion round, which causes the issue.
5) As a result, it fails at `torch.empty((a, 4))` (there is no constraint that `a` must be positive).
The fix here is to implement the range constraint logic as a native aten op (with the CPU implementation as a no-op) so that it survives AOT export.
**NOTE:**
The [logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to handle the case where a non-`SymInt` is passed in, and is reused in the new `_constrain_range`. The reason is that when a non-`SymInt` is provided:
* If it directly calls `sym_constrain_range`, the C++ version will be called, which is a no-op.
* So in this case it calls `constrain_range_int` instead, to be able to catch issues like the user providing an input whose tensor shape is out of range during export, like the following for the code example above:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,)) # immediately raise error
```
Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
v2 of https://github.com/pytorch/pytorch/pull/102125 because of git issues
corresponding deserialization diff: https://github.com/pytorch/pytorch/pull/102716
Implementing serialization of the exported program to a Python dataclass, and then from that dataclass to JSON. This is split into a couple of sections:
- `serialize(ep: ep.ExportedProgram, opset_version: Dict[str, int]) -> Tuple[bytes, bytes]` -- takes an exported program object and a dictionary mapping opset namespaces to versions, and returns the serialized exported program in bytes and, separately, the state dict serialized in bytes (see the usage sketch after this list)
- `GraphModuleSerializer` class that serializes torch.fx.GraphModule to the schema.GraphModule dataclass
- `ExportedProgramSerializer` class that serializes torch._export.exported_program.ExportedProgram to the schema.ExportedProgram dataclass
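A usage sketch of the first entry point above; the import path, the placeholder `ep`, and the example opset value are assumptions:
```python
# Assumed import path; `ep` stands for an ExportedProgram produced by torch._export.export(...).
from torch._export.serde.serialize import serialize

serialized_ep, serialized_state_dict = serialize(ep, opset_version={"aten": 10})
with open("exported_program.json", "wb") as f:
    f.write(serialized_ep)
with open("state_dict.bin", "wb") as f:
    f.write(serialized_state_dict)
```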
Serialization TODOs:
- [x] pytree spec: https://github.com/pytorch/pytorch/pull/102577
- [ ] higher order ops
- [ ] node metadata (specifically nn_module_stack/source_fn)
- [ ] constraints
- [ ] graph module metadata
The tests are not super comprehensive, but that's because I think it'll be better tested + easier to test once deserialization is implemented.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102707
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
There are some I can't easily switch due to reasons like:
- Dynamo modelling the guard
- BC concerns (for torch.autograd.set_multithreading_enabled)
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102642
Approved by: https://github.com/albanD
This PR adds aot_export_module as the lowering path from the torch-level graph to the aten graph (see the sketch after this list). Some known limitations that need to be addressed in follow-up PRs:
1. Store param/buffer data in ExportedProgram
2. Fully support torch.cond with params/buffers
3. Making sure no duplicated ExportMetaData entry
4. This API will break Executorch if used on PyE; we will figure out a plan internally.
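A hedged sketch of the lowering call this wires in; the import path, keyword arguments, and return values are assumptions about aot_export_module's interface:
```python
import torch
from torch._functorch.aot_autograd import aot_export_module  # assumed import path

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x).relu()

# Lower the torch-level module to an aten-level graph plus a signature that
# records how parameters, buffers, and user inputs map onto graph inputs.
aten_gm, graph_signature = aot_export_module(M(), (torch.randn(2, 4),), trace_joint=False)
print(aten_gm.graph)
```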
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101490
Approved by: https://github.com/avikchaudhuri
Previously we had runtime asserts for range constraints. This diff adds runtime asserts for equality constraints.
This requires a bit of refactoring that is worth calling out.
1. [Minor] Some of the data structures produced by export and consumed by the runtime assertion pass need to be broadened. This is a WIP. There are some associated code improvements that are included in this diff, but by and large the structures are similar to what exists now. Meanwhile @angelayi and I are chatting about how to make it qualitatively better: briefly, we want to index everything by symbols, which are 1-1 with (name, dim) pairs.
2. [Major] The order in which runtime asserts are emitted is changed. Previously we used to do the work in `placeholder`, now this diff adds a hook for "post-processing" after processing of all placeholders is done. This is needed because equality constraints can mention different placeholders. This change also opens the way to optimizing codegen: e.g., each (name, dim) pair should correspond to a single intermediate variable that is reused across runtime asserts. This is future work.
Differential Revision: [D46177642](https://our.internmc.facebook.com/intern/diff/D46177642/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102256
Approved by: https://github.com/tugsbayasgalan, https://github.com/angelayi
Summary: Forward fix for t53725825. The new map implementation breaks multiple internal tests; this forward-fixes some of them. To unblock the others, mark the unfixed ones as expectedFailure for now.
Test Plan: Test with CI.
Reviewed By: angelayi
Differential Revision: D46084287
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102009
Approved by: https://github.com/angelayi
This PR does the following:
1. Previously, inline constraints were not properly set for data-dependent ops with tensor outputs, such as a.nonzero, because their return value is not a SymInt. This PR just uses all the unbacked symbols, i.e. those starting with "i"/"f" in the create_unbacked_sym* functions. Note that these symbols are guaranteed to be a superset of the inline user constraints.
2. Adds inline assertion support by checking these constraints at runtime.
Currently, it only deals with data-dependent ops whose outputs are tensors, SymInt, SymFloat, or SymBool, and ignores the rest. That's good enough for now, as we only have a limited number of data-dependent ops (.item and .nonzero are explicitly tested).
An example of a graph with added assertions is shown below:
```
class ExportGraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        nonzero_default: i64[i0, 1] = torch.ops.aten.nonzero.default(arg0);  arg0 = None
        return pytree.tree_unflatten([nonzero_default], self._out_spec)

class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        sym_size: Sym(s0) = torch.ops.aten.sym_size(arg0, 0)
        nonzero_default: i64[i1, 1] = torch.ops.aten.nonzero.default(arg0);  arg0 = None
        sym_size_1: Sym(i1) = torch.ops.aten.sym_size(nonzero_default, 0)
        ge: Sym(i1 >= 3) = sym_size_1 >= 3
        scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(ge);  ge = None
        _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].');  scalar_tensor_default = None
        le: Sym(i1 <= 5) = sym_size_1 <= 5;  sym_size_1 = None
        scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(le);  le = None
        _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].');  scalar_tensor_default_1 = None
        return pytree.tree_unflatten([nonzero_default], self._out_spec)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100763
Approved by: https://github.com/tugsbayasgalan
I ported over the code for the inline interpreter incorrectly in the pass base 😅
Originally the function `make_inline_interpreter` was supposed to take an fx.Interpreter type, but I accidentally passed in an fx.Interpreter object. I also realized while modifying this diff (prompted by Tugsuu's comments) that we don't really need this InlineInterpreter.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100836
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
* Added ExportPassBase, an interpreter-based helper class for writing passes (see the sketch after this list).
* It can also help maintain the dialect based on the operator namespace by having users override the `get_valid_dialects` function (returning an empty list implies the pass works for any dialect).
* Added a `ReplaceBrokenOpsWithFunctionalOpsPass` to replace all ops that have not been converted by functionalization with their functional counterparts.
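A sketch of subclassing the helper class described above; the import path and the calling convention are assumptions:
```python
# Import path assumed from torch/_export/pass_base.py referenced elsewhere in this log.
from torch._export.pass_base import ExportPassBase

class AnyDialectPass(ExportPassBase):
    def get_valid_dialects(self) -> list:
        # An empty list means the pass is valid for any dialect; returning
        # specific dialects restricts where the pass may be applied.
        return []

# Interpreter-based usage (assumed): calling the pass on a graph module
# re-interprets it node by node and returns a PassResult.
# result = AnyDialectPass()(graph_module)
```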
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100000
Approved by: https://github.com/gmagogsfm
Wrapper for users to insert constraints into model code.
The constraints will not be maintained in the graph after tracing through make_fx, so retracing with dynamo/make_fx will not work. This will be supported once torch._assert support is implemented; then we can convert the constrain_range calls to torch._asserts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98433
Approved by: https://github.com/avikchaudhuri, https://github.com/tugsbayasgalan
The purpose of this API is to carry out a few large pieces of work:
1) Refactor all the internals of plumbing dynamic dimension information after dynamo to be stateless
2) Decouple allocation controls around dynamic dimensions from verification
3) For (2), for allocation, create an enum that dictates whether we are in DUCK (the default today), STATIC (aka assume_static_default in the past), or DYNAMIC (user constrained, do not duck shape) mode; see the sketch after this list
4) For (2), for verification, we separate the list of dynamic ranges entirely from allocation. This means the shape_env does no tracking of what we verify on; instead, it is the caller's job to invoke produce_guards() with the various things they want verified, specifically with the valid ranges. We do use constraint ranges to refine value ranges when doing analysis.
5) We have therefore decided, as an extension of (4), to double down on "late" checks versus "eager" checks, primarily because the mechanism for gathering what actually matters happens during guard creation and should be the purview of the caller seeking guards, not the shape env. However, for dynamo, these structures are essentially one and the same.
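A sketch of the allocation enum described in point 3; the member semantics follow the text above, but the class name and representation are illustrative:
```python
from enum import Enum, auto

class DimAllocationPolicy(Enum):  # illustrative name, not necessarily the real one
    DUCK = auto()     # default today: duck shaping, equal sizes share a symbol
    STATIC = auto()   # previously assume_static_default: specialize to a constant
    DYNAMIC = auto()  # user constrained: fresh symbol, no duck shaping
```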
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96699
Approved by: https://github.com/avikchaudhuri, https://github.com/ezyang
This makes the next PR in the stack cleaner: having the top level entry point to aot autograd perform the functionalization analysis pass once, and plumb the metadata everywhere else that we need it.
I put it in a separate PR because I recently learned that this function is used in fbcode, so I'll need to fix up internals when I land this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95991
Approved by: https://github.com/ezyang
OK, so this PR used to be about reducing the number of constants we specialize on, but it turns out that unspecialization was ~essentially never used (because we still constant specialized way too aggressively) and I ended up having to fix a bunch of issues to actually get tests to pass. So this PR is now "make int unspecialization actually work". As part of this, I have to turn off unspecialization by default, as there are still latent bugs in inductor.
The general strategy is that an unspecialized int is represented as a SymInt. Representing it as a 0d tensor (which is what the code used to do) is untenable: (1) we often need unspecialized ints to participate in size computations, but we have no way of propagating sympy expressions through tensor compute, and (2) a lot of APIs work when passed SymInt, but not when passed a Tensor. However, I continue to represent Numpy scalars as Tensors, as they are rarely used for size computation and they have an explicit dtype, so they are more accurately modeled as 0d tensors.
* I folded in the changes from https://github.com/pytorch/pytorch/pull/95099 as I cannot represent unspecialized ints as SymInts without also turning on dynamic shapes. This also eliminates the necessity for test_unspec.py, as toggling specialization without dynamic shapes doesn't do anything. As dynamic shapes defaults to unspecializing, I just deleted this entirely; for the specialization case, I rely on regular static shape tests to catch it. (Hypothetically, we could also rerun all the tests with dynamic shapes, but WITH int/float specialization, but this seems... not that useful? I mean, I guess export wants it, but I'd kind of like our Source heuristic to improve enough that export doesn't have to toggle this either.)
* Only 0/1 integers get specialized by default now
* A hodgepodge of fixes. I'll comment on the PR about them.
Fixes https://github.com/pytorch/pytorch/issues/95469
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95621
Approved by: https://github.com/jansel, https://github.com/Chillee
This is a WIP PR for adding a torch.export API in OSS. A couple of points:
- I intentionally named it experimental_export so that people don't get confused into thinking this is our official API
- We don't plan to use the AOTAutograd backend just yet. The reason we have it here is that the functionalization AOTAutograd uses is what we need for export (handling of param/buffer mutation, etc.). In the near future, I will extract the functionalization part and use it on top of make_fx. What we have right now is merely a placeholder.
- The reason we want to do this now is that we want some minimal tests running in OSS so that we can catch regressions earlier.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95070
Approved by: https://github.com/gmagogsfm, https://github.com/zhxchen17
Scalar is a union type of [int, float, bool]; it's only needed for the representation of operator schemas.
During export, we always have the concrete argument. As ex.Argument is already a union type, we don't need the Scalar type anymore.
Example:
Here's the schema for aten.add.Scalar
```
add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
```
An fx.node
```
add_tensor: f32[s0, s0] = torch.ops.aten.add.Scalar(arg0, 1.1)
```
would be exported as
```
Node(
    op='call_function',
    target='aten.add.Tensor',
    args=[
        Argument(as_tensor=TensorArgument(name='arg0')),
        Argument(as_float=1.1)
    ],
    outputs=[
        ReturnArgument(as_tensor=TensorArgument(name='add_tensor'))
    ]
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93211
Approved by: https://github.com/suo
Nodes can only be 'call_function' ops
'placeholder' and 'output' are serialized as inputs and outputs of the Graph
'get_attr' is not needed anymore, as it's an implicit lookup from the GraphModule's parameters/buffers
'call_method' and 'call_module' are not supported, as they are not used in the canonical FX graph
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93208
Approved by: https://github.com/suo, https://github.com/Neilblaze