Commit Graph

34 Commits

Author SHA1 Message Date
angelayi
cbbc309cae [pytree][reland] Require pytree serialized_type_name (#120636)
Relanding https://github.com/pytorch/pytorch/pull/119718 now that the diff that prevents torchrec breakages, [D53857843](https://www.internalfb.com/diff/D53857843), has landed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120636
Approved by: https://github.com/avikchaudhuri
2024-02-27 06:53:33 +00:00
ydwu4
8d81e61fb6 [export] make node_inline_ also inline the get_item calls (#119913)
As titled. Before this PR, after we split and then inline_, getitem calls remained in the graph even though the original graph module doesn't have them. This PR removes the extra getitem calls by inlining them.

Test Plan:
Added new test cases for graphs that return multiple outputs and take multiple inputs
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119913
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #119732, #119736, #119810
2024-02-17 02:18:27 +00:00
ydwu4
4769e6916a [export] add node_inline_ to prepare replacing set_grad_enabled with hop (#119736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119736
Approved by: https://github.com/tugsbayasgalan
ghstack dependencies: #119732
2024-02-17 02:18:11 +00:00
ydwu4
068659ddc2 [export] add sequential_split to prepare replacing set_grad_enabled with hop (#119732)
This PR is the first of N PRs transforming global-state-mutating ops, such as torch._C._set_grad_enabled calls, in the pre-dispatch graph into a higher-order op so that the graph becomes more functional. We make use of split_module to perform the transformation.

This PR preserves the node.name values of the original module by adding a new kwarg `keep_original_node_name` to split_module.

For a graph that looks like this:
```python
def forward(self, arg_0):
    arg0_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1);  sum_1 = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
    sub = torch.ops.aten.sub.Tensor(add_1, 1)
    return pytree.tree_unflatten((add_1, sub), self._out_spec)
```
Before the change, split_module returns the following graph and submodules (notice the renaming from `add` -> `add_tensor` and `sin` -> `sin_default`):
```python
def forward(self, arg_0):
    arg0_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
    submod_0 = self.submod_0(arg0_1);  arg0_1 = None
    submod_1 = self.submod_1(submod_0);  submod_0 = None
    submod_2 = self.submod_2(submod_1)
    return pytree.tree_unflatten((submod_1, submod_2), self._out_spec)

# submod_0
def forward(self, arg0_1):
    add_tensor = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
    sin_default = torch.ops.aten.sin.default(add_tensor);  add_tensor = None
    sum_default = torch.ops.aten.sum.default(sin_default);  sin_default = None
    return sum_default

# submod_1
def forward(self, sum_1):
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    add_tensor = torch.ops.aten.add.Tensor(sum_1, 1);  sum_1 = None
    return add_tensor

# submod_2
def forward(self, add_1):
    _set_grad_enabled = torch._C._set_grad_enabled(True)
    sub_tensor = torch.ops.aten.sub.Tensor(add_1, 1);  add_1 = None
    return sub_tensor
    """)

```

After the change, the test produces the following graph; all node names from the original graph module are preserved in the submodules.
```python

def forward(self, arg_0):
    sub, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
    submod_0 = self.submod_0(sub);  sub = None
    submod_1 = self.submod_1(submod_0);  submod_0 = None
    submod_2 = self.submod_2(submod_1)
    return pytree.tree_unflatten((submod_1, submod_2), self._out_spec)

# submod_0
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 1);  arg0_1 = None
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    return sum_1

# submod_1
def forward(self, sum_1):
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1);  sum_1 = None
    return add_1

# submod_2
def forward(self, add_1):
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
    sub = torch.ops.aten.sub.Tensor(add_1, 1);  add_1 = None
    return sub

```

Note that currently, we call split_module on the graph after pre-dispatch AOT export. The difference is even larger if we `split_module` the graph module produced by dynamo, where all the original variable names in the user program are preserved after dynamo but lost after `split_module` without this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119732
Approved by: https://github.com/tugsbayasgalan
2024-02-17 02:18:04 +00:00
Wilson Hong
3f4dd9bfa4 Back out "[pytree] Require serialized_type_name" (#120041)
Summary:
D53785493 breaks apf.rec.ir.tests.ir_export_deserialize_test.IRExportDeserializeTest; test_export_deserialize_ebc failed:

https://www.internalfb.com/sandcastle/workflow/3436246515685789584

Test Plan: buck2 test mode/opt apf/rec/ir/tests:ir_export_deserialize_test

Differential Revision: D53834881

Co-authored-by: Wilson Hong <wilsonhong@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120041
Approved by: https://github.com/ydwu4
2024-02-16 10:02:25 +00:00
angelayi
b4c7afe101 [pytree] Require serialized_type_name (#119718)
Differential Revision: [D53785493](https://our.internmc.facebook.com/intern/diff/D53785493)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119718
Approved by: https://github.com/suo
2024-02-15 20:32:44 +00:00
Angela Yi
0827510fd3 [export] Remove torch._export.export (#119095)
XLA changes: https://github.com/pytorch/xla/pull/6486

Test Plan: CI

Differential Revision: D53316196

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119095
Approved by: https://github.com/ydwu4, https://github.com/zhxchen17, https://github.com/tugsbayasgalan, https://github.com/avikchaudhuri, https://github.com/jerryzh168
2024-02-08 21:22:04 +00:00
Michael Suo
bf4e171539 [export] support non-persistent buffers (#118969)
Summary:
X-link: https://github.com/pytorch/executorch/pull/1817

Basic support for non-persistent buffers, which are buffers that do not show up in the state dict.
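
A minimal sketch of the distinction (not from the diff): a non-persistent buffer lives on the module but is excluded from `state_dict()`.

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer (the default): appears in state_dict().
        self.register_buffer("running", torch.zeros(3))
        # Non-persistent buffer: stays on the module but is excluded from state_dict().
        self.register_buffer("scratch", torch.zeros(3), persistent=False)

m = M()
assert "running" in m.state_dict()
assert "scratch" not in m.state_dict()
assert "scratch" in dict(m.named_buffers())  # still a buffer on the module
```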

One weird twist is that most of our other systems (FX, aot_export, dynamo) have completely buggy handling of non-persistent buffers. I tried to go on a wild goose chase to fix them all, but it got to be too much. So I introduced some sad rewrite passes in `_export` to make the final state dict correctly align with the original module's state dict.

This exposed some bugs/ambiguous handling of parameters/buffers in existing test code. For example, `TestSaveLoad.test_save_buffer` traced over a module that was not in the root module hierarchy and caused some weird behavior. I think we should error explicitly on use cases like this: https://github.com/pytorch/pytorch/issues/118410. For now I just rewrote the tests or skipped them.

As a side effect, this diff tightened up quite a few sloppy behaviors around state dict handling:
- Tensor attributes were getting promoted to be buffers—bad!
- Tracing through a module not in the children of the root module would add its parameters/buffers to the state dict—bad!

This behavior is unlikely to show up in user code since the model would be totally broken, but did show up in a bunch of tests.

#buildmore

Test Plan:
unit tests
sandcastle

Differential Revision: D53340041

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118969
Approved by: https://github.com/guangy10, https://github.com/huydhn, https://github.com/titaiwangms
2024-02-02 19:16:08 +00:00
Aaron Gokaslan
1562dae62c [BE]: Apply RUF025 dict.fromkeys preview rule (#118637)
Simplifies and optimizes dict construction using the `fromkeys` classmethod ctor. This also makes it really obvious when all the keys will have the same static value, which could be a bug if unintentional. It is also significantly faster than using a dict comprehension. The rule is in preview, but I am adding a forward fix for when it becomes stable.
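
For illustration (not from the PR), the rewrite the rule suggests, plus the shared-value caveat mentioned above:

```python
keys = ["a", "b", "c"]

# Dict comprehension with a static value (the pattern RUF025 flags):
d1 = {k: 0 for k in keys}

# Preferred: the fromkeys classmethod; faster, and the shared value is explicit.
d2 = dict.fromkeys(keys, 0)
assert d1 == d2

# Caveat: a mutable value is shared by every key.
d3 = dict.fromkeys(keys, [])
d3["a"].append(1)
assert d3["b"] == [1]  # the same list object everywhere
```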

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118637
Approved by: https://github.com/albanD
2024-01-30 20:46:54 +00:00
suo
4ee8aa6028 [export] adopt KeyPath API in nonstrict mode (#118609)
This PR rewrites two paths to use the newly-added keypaths API in pytree:
First: we were hand-rolling a tree_map during fakification because we wanted to track sources. This PR uses keypaths instead, which can do the same thing without needing custom code.

Second: our constraint error formatting was referencing placeholder names in error messages. These placeholder names are not otherwise user-visible, so they are super confusing to users (e.g. "which input does arg1_3 correspond to?"). This diff uses the `keystr` API to format the error message.

This necessitated some small refactors—generating the keystr is expensive so doing it in an f-string was very bad.

It can also be further improved—we can inspect the signature so that instead of `*args[0]` we can give people the actual argument name, which would be the ideal UX. But leaving that for later.
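
A minimal sketch of the keypath pattern, assuming the `tree_map_with_path`/`keystr` helpers in `torch.utils._pytree`:

```python
import torch
import torch.utils._pytree as pytree

inputs = {"x": torch.randn(2), "nested": [torch.randn(3), torch.randn(4)]}

# The callback receives (keypath, leaf), so each fakified input can be reported
# against a user-visible path such as "['nested'][0]" instead of a placeholder name.
def fakify(keypath, leaf):
    print(f"input{pytree.keystr(keypath)}: shape {tuple(leaf.shape)}")
    return leaf

fake_inputs = pytree.tree_map_with_path(fakify, inputs)
```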

Differential Revision: [D53139358](https://our.internmc.facebook.com/intern/diff/D53139358/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118609
Approved by: https://github.com/zhxchen17
ghstack dependencies: #118607, #118608
2024-01-30 19:14:11 +00:00
Angela Yi
413a434846 [export] Convert all export tests to .module() (#118425)
Test Plan: CI

Differential Revision: D53075379

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118425
Approved by: https://github.com/suo
2024-01-29 23:06:54 +00:00
Angela Yi
5c56822be2 [export] Various fixes to .module() (#118272)
Summary: While turning on .module() for all the export tests, I uncovered some bugs with .module(), and while fixing them I ended up rewriting some of the code... Some of the bugs were:

* bad kwargs support on the unlifted module
* no support for user input mutations
* (at the commit hash I was working off of) no support for custom objects
* there were no tests on unlifting weights from cond/map submodules

Test Plan: CI

Differential Revision: D53075380

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118272
Approved by: https://github.com/suo
2024-01-26 21:05:07 +00:00
Angela Yi
7dac2f9f2d [export][ez] Fix getting meta["val"] (#117313)
Summary: Integer inputs do not have a meta["val"].
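
The defensive pattern this implies (a sketch, not the actual diff):

```python
import torch.fx

g = torch.fx.Graph()
node = g.placeholder("num_steps")  # a plain-int input: nothing has attached meta["val"]

# node.meta["val"] would raise KeyError here; .get() degrades gracefully.
val = node.meta.get("val")
print(val)  # None
```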

Test Plan: `buck run @//mode/dev-nosan  //executorch/examples/portable/scripts:export -- -m emformer_predict` passes the export step

Differential Revision: D52716419

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117313
Approved by: https://github.com/kirklandsign, https://github.com/tugsbayasgalan
2024-01-12 06:17:38 +00:00
Angela Yi
8e2d63cbc3 [export][reland] Remove runtime assertion pass (#115597)
Summary:
Reland of https://github.com/pytorch/pytorch/pull/115196
D52054112 to fix internal failures.

Test Plan: CI

Differential Revision: D52054110

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115597
Approved by: https://github.com/ydwu4, https://github.com/zhxchen17
2023-12-15 03:22:03 +00:00
PyTorch MergeBot
4186932bac Revert "[export] Remove runtime assertion pass (#115196)"
This reverts commit c163b3c035.

Reverted https://github.com/pytorch/pytorch/pull/115196 on behalf of https://github.com/atalman due to Broke internal test ([comment](https://github.com/pytorch/pytorch/pull/115196#issuecomment-1847778344))
2023-12-08 20:07:04 +00:00
angelayi
c163b3c035 [export] Remove runtime assertion pass (#115196)
Reland of https://github.com/pytorch/pytorch/pull/111949/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115196
Approved by: https://github.com/avikchaudhuri
2023-12-07 01:44:11 +00:00
Xuehai Pan
2a3d8e50fb [pytree] test aligned API signature for C++ and Python pytree (#112485)
Add tests to ensure the C++ and Python pytree provide the same APIs with identical signatures.
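
A simplified sketch of such a parity check; using `torch.utils._cxx_pytree` as the name of the C++-backed module is an assumption here, and importing it needs `optree` installed:

```python
import inspect

import torch.utils._pytree as python_pytree
import torch.utils._cxx_pytree as cxx_pytree  # C++-backed implementation; module name assumed

# Compare parameter names for a few shared entry points.
for name in ("tree_flatten", "tree_unflatten", "tree_map"):
    py_params = list(inspect.signature(getattr(python_pytree, name)).parameters)
    cxx_params = list(inspect.signature(getattr(cxx_pytree, name)).parameters)
    assert py_params == cxx_params, f"{name}: {py_params} != {cxx_params}"
```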

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112485
Approved by: https://github.com/zou3519
2023-11-30 17:50:06 +00:00
Xuehai Pan
89a1fe6966 [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register a node type in both the C++ and Python implementations (usage sketched below).
6. Add tests to ensure a warning will be raised when the old private function is called.
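
A usage sketch of the unified API from item 5, assuming today's `torch.utils._pytree` surface:

```python
import torch.utils._pytree as pytree

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

# Item 5: one call registers the node type for both the C++ and Python implementations.
pytree.register_pytree_node(
    Point,
    lambda p: ([p.x, p.y], None),            # flatten -> (children, context)
    lambda children, ctx: Point(*children),  # unflatten
)

leaves, spec = pytree.tree_flatten(Point(1, 2))
assert leaves == [1, 2]
assert isinstance(pytree.tree_unflatten(leaves, spec), Point)
```

Per the `serialized_type_name` commits above, serializing a treespec that contains `Point` additionally requires passing `serialized_type_name=` at registration.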

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-28 11:41:38 +00:00
PyTorch MergeBot
01366efcc9 Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit 4e4a6ad6ec.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1824099658))
2023-11-23 09:59:32 +00:00
Xuehai Pan
4e4a6ad6ec [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-21 19:53:13 +00:00
Tugsbayasgalan Manlaibaatar
a7b75f586a [RELAND] Disallow skipping dynamo (#110222)
Previous discussion: https://github.com/pytorch/pytorch/pull/109476

In this PR, I made the following additions to the original PR:
1) The unlifted graph module now runs the runtime assertions in its forward call.
2) When we retrace, we run the assertions to make sure the user is tracing the module with inputs that respect the assumptions we made during the first tracing. The way I do this is by creating a new graph module type with a modified call method. The runtime assertions happen under torchdynamo.disable so they just run eagerly; we don't want them to become a traced part of the graph (sketched below).
3) Both ep.module and capture_pre_autograd now return _UnliftedGraphModule.
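
A minimal sketch of the pattern in items 1 and 2; apart from the `_UnliftedGraphModule` name and `torch._dynamo.disable`, everything here is illustrative:

```python
import torch

@torch._dynamo.disable  # run eagerly; never traced into the graph
def _check_retrace_assumptions(x):
    assert x.dtype == torch.float32, "input dtype violates an export-time assumption"

class _UnliftedGraphModule(torch.nn.Module):  # name from the commit; body is a sketch
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        _check_retrace_assumptions(x)  # items 1-2: assertions run inside forward
        return self.inner(x)
```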

Differential Revision: [D51078056](https://our.internmc.facebook.com/intern/diff/D51078056)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110222
Approved by: https://github.com/zhxchen17
2023-11-14 16:02:01 +00:00
PyTorch MergeBot
2a271a3efa Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit a0d00349ed.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/PaliC due to _private_register_pytree_node now checks for duplicate registering, unfortunately, this breaks composability with torchrec internally :(  ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1806130993))
2023-11-10 17:24:40 +00:00
Xuehai Pan
a0d00349ed [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-10 02:41:30 +00:00
Xuehai Pan
5e2adc8650 [pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations.

- #112485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482
Approved by: https://github.com/zou3519
2023-11-10 02:37:48 +00:00
PyTorch MergeBot
66150b29e3 Revert "[pytree] align function signature between C++ and Python pytree (#112482)"
This reverts commit 4893a2814f.

Reverted https://github.com/pytorch/pytorch/pull/112482 on behalf of https://github.com/PaliC due to changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112482#issuecomment-1804909926))
2023-11-10 00:59:23 +00:00
PyTorch MergeBot
9a90989121 Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit 95f52611c7.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/PaliC due to in the bottom diff in the stack changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1804892924))
2023-11-10 00:38:28 +00:00
Xuehai Pan
95f52611c7 [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-08 05:02:03 +00:00
Xuehai Pan
4893a2814f [pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations.

- #112485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482
Approved by: https://github.com/zou3519
2023-11-07 01:26:41 +00:00
angelayi
ff35e1e45b [pytree] Add custom treespec fqn field (#112428)
Custom classes that are serialized with pytree are serialized by default with `f"{class.__module__}.{class.__name__}"`. This is a dependency from our serialized program directly into the outer Python environment: if a user moves the class to a different directory, the serialized program can no longer be loaded. So, we will require users to pass in an FQN if they want to serialize their custom treespec type.
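
A hedged example (the class and FQN are made up; the kwarg name comes from the `serialized_type_name` commits above):

```python
import torch.utils._pytree as pytree

class Config:
    def __init__(self, entries):
        self.entries = entries

pytree.register_pytree_node(
    Config,
    lambda c: (list(c.entries), None),
    lambda children, ctx: Config(children),
    # Without this, the default would be f"{Config.__module__}.{Config.__name__}",
    # which breaks deserialization if the class later moves to another module.
    serialized_type_name="myproject.config.Config",
)
```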

Differential Revision: [D50886366](https://our.internmc.facebook.com/intern/diff/D50886366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112428
Approved by: https://github.com/suo
2023-11-02 00:26:41 +00:00
Xuehai Pan
a7a0955790 [pytree][BE] reorganize imports and format code style and update type hints (#112268)
Reland PR:

- #112109

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112268
Approved by: https://github.com/Skylion007
2023-10-28 16:30:24 +00:00
angelayi
a432f37e49 Serialize pytree to json string (#106116)
Fixes https://github.com/pytorch/pytorch/pull/102577#issuecomment-1650905536

Serializing to JSON is more stable, and the API has been renamed:

```python
# Takes in a treespec and returns the serialized treespec as a string.
# Also optionally takes in a protocol version number.
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: ...

# Takes in a serialized treespec and outputs a TreeSpec.
def treespec_loads(data: str) -> TreeSpec: ...
```

If users want to register their own serialization format for a given pytree, they can go through the `_register_treespec_serializer` API, which optionally takes in a `getstate` and `setstate` function.
```python
_register_treespec_serializer(type_, *, getstate, setstate)

# Takes in the context and outputs a json-dumpable context.
def getstate(context: Context) -> DumpableContext: ...

# Takes in a json-dumpable context and reconstructs the original context.
def setstate(dumpable_context: DumpableContext) -> Context: ...
```

We will serialize to the following dataclass, and then json.dumps it to a string.
```python
@dataclass
class TreeSpec:
    type: Optional[str]           # string name of the type; None for a LeafSpec
    context: Optional[Any]        # optional, a json-dumpable form of the context
    children_specs: List["TreeSpec"]
```

If no getstate/setstate function is registered, we will by default serialize the context using `json.dumps/loads`. We will also serialize the type through `f"{typ.__module__}.{typ.__name__}"`.
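
A short usage sketch, assuming today's `torch.utils._pytree` surface:

```python
import torch.utils._pytree as pytree

_, spec = pytree.tree_flatten({"a": 1, "b": [2, 3]})

s = pytree.treespec_dumps(spec)  # JSON string; a protocol number may be passed as well
assert pytree.treespec_loads(s) == spec
```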

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106116
Approved by: https://github.com/zou3519
2023-08-27 14:34:49 +00:00
Chen Lai
4f2ff1d019 add get buffer from exported program (#107809)
Summary: We have a util function to get params; for parity we also need a util function to get buffers.
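
A sketch of how such parity helpers might be used; the `is_buffer`/`get_buffer` names and their `(program, node)` signature are assumptions based on this summary:

```python
import torch
from torch._export.utils import get_buffer, is_buffer  # helper names assumed

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(1))

    def forward(self, x):
        return x * self.scale

ep = torch.export.export(M(), (torch.randn(2),))
for node in ep.graph.nodes:
    if node.op == "placeholder" and is_buffer(ep, node):
        print(node.target, get_buffer(ep, node))
```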

Test Plan:
```
buck test //caffe2/test:test_export
```

Differential Revision: D48610877

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107809
Approved by: https://github.com/JacobSzwejbka
2023-08-25 05:46:04 +00:00
Jacob Szwejbka
c14f4d66c3 [pytorch][export] Move is_param and get_param out of exir and into export (#107264)
Summary: These don't feel edge-specific, so moving them out of exir.

Test Plan: ci

Differential Revision: D48361384

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107264
Approved by: https://github.com/angelayi
2023-08-22 21:41:51 +00:00
ydwu4
5237ed55e6 [export] allow register dataclass as pytree node (#106160)
In this PR, we allow users to register customized flatten/unflatten/serialization/deserialization functions for a dataclass. We provide a default implementation for flatten/unflatten; a decorator could be implemented on top of it when needed.

## Motivation:
HuggingFace and many internal models return dataclass outputs, and torch.export wants to maintain the invariant that the export result (i.e. the exported_program) has the same calling convention and results as the original callable.

This is not supported in export yet: we cannot recover the original dataclass from the flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need a place to store the metadata of the dataclass so that we can re-construct it. To avoid adding hacky code in export and to allow principled extensibility, we think extending pytree may be a good option.

## Implementation:
@zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and [jax-2371](https://github.com/google/jax/issues/2371#issuecomment-805361566), which suggest that it's not a good idea to make dataclass a default pytree node, though it could be good to provide a default implementation for dataclasses. Since this currently seems to be an export-only feature, we added this extension point in export.

We also add a `return_none_fields` flag to control whether None fields are returned after flattening; this is expected to be False in produce_matching of dynamo.export.

Also added some tests.
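
A usage sketch; `torch.export.register_dataclass` is the present-day public entry point, which is an assumption relative to this commit (the original hook lived under export):

```python
from dataclasses import dataclass

import torch

@dataclass
class ModelOutput:
    logits: torch.Tensor
    loss: torch.Tensor

torch.export.register_dataclass(ModelOutput, serialized_type_name="mymodule.ModelOutput")

class M(torch.nn.Module):
    def forward(self, x):
        return ModelOutput(logits=x.softmax(-1), loss=x.sum())

ep = torch.export.export(M(), (torch.randn(2, 4),))
out = ep.module()(torch.randn(2, 4))
assert isinstance(out, ModelOutput)  # same calling convention and result type as the original
```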

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106160
Approved by: https://github.com/zhxchen17
2023-07-28 17:33:13 +00:00