Commit Graph

11 Commits

Author SHA1 Message Date
Xuehai Pan
5e2adc8650 [pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations.

- #112485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482
Approved by: https://github.com/zou3519
2023-11-10 02:37:48 +00:00
PyTorch MergeBot
66150b29e3 Revert "[pytree] align function signature between C++ and Python pytree (#112482)"
This reverts commit 4893a2814f.

Reverted https://github.com/pytorch/pytorch/pull/112482 on behalf of https://github.com/PaliC due to changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112482#issuecomment-1804909926))
2023-11-10 00:59:23 +00:00
PyTorch MergeBot
9a90989121 Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit 95f52611c7.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/PaliC due to in the bottom diff in the stack changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1804892924))
2023-11-10 00:38:28 +00:00
Xuehai Pan
95f52611c7 [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-08 05:02:03 +00:00
Xuehai Pan
4893a2814f [pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations.

- #112485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482
Approved by: https://github.com/zou3519
2023-11-07 01:26:41 +00:00
angelayi
ff35e1e45b [pytree] Add custom treespec fqn field (#112428)
Custom classes that are serialized with pytree are serialized by default with `f”{class.__module__}.{class.__name__}”`. This is a dependency from our serialized program directly into the outer Python environment. If a user moves the class to a different directory, the serialized program will be unable to be loaded. So, we will require users to pass in an FQN if they want to serialize their custom treespec type.

Differential Revision: [D50886366](https://our.internmc.facebook.com/intern/diff/D50886366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112428
Approved by: https://github.com/suo
2023-11-02 00:26:41 +00:00
Xuehai Pan
a7a0955790 [pytree][BE] reorganize imports and format code style and update type hints (#112268)
Reland PR:

- #112109

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112268
Approved by: https://github.com/Skylion007
2023-10-28 16:30:24 +00:00
angelayi
a432f37e49 Serialize pytree to json string (#106116)
Fixes https://github.com/pytorch/pytorch/pull/102577#issuecomment-1650905536

Serializing to json is more stable, and renamed the API:

```
# Takes in a treespec and returns the serialized treespec as a string. Also optionally takes in a protocol version number.
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
# Takes in a serialized treespec and outputs a TreeSpec
def treespec_loads(data: str) -> TreeSpec:
```

If users want to register their own serialization format for a given pytree, they can go through the `_register_treespec_serializer` API which optionally takes in a `getstate` and `setstate` function.
```
_register_treespec_serializer(type_, *, getstate, setstate)
# Takes in the context, and outputs a json-dumpable context
def getstate(context: Context) -> DumpableContext:
# Takes in a json-dumpable context, and reconstructs the original context
def setstate(dumpable_context: DumpableContext) -> Context:
```

We will serialize to the following dataclass, and then json.dump this it to string.
```
class TreeSpec
    type: Optional[str]  # a string name of the type. null for the case of a LeafSpec
    context: Optional[Any]  # optional, a json dumpable format of the context
    children_specs: List[TreeSpec],
}
```

If no getstate/setstate function is registered, we will by default serialize the context using `json.dumps/loads`. We will also serialize the type through `f"{typ.__module__}.{typ.__name__}"`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106116
Approved by: https://github.com/zou3519
2023-08-27 14:34:49 +00:00
Chen Lai
4f2ff1d019 add get buffer from exported program (#107809)
Summary: We have the util function to get params, for parity we also need util function to get buffer`

Test Plan:
```
buck test //caffe2/test:test_export
```

Differential Revision: D48610877

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107809
Approved by: https://github.com/JacobSzwejbka
2023-08-25 05:46:04 +00:00
Jacob Szwejbka
c14f4d66c3 [pytorch][export] Move is_param and get_param out of exir and into export (#107264)
Summary: These doesn't feel edge specific so moving out of exir.

Test Plan: ci

Differential Revision: D48361384

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107264
Approved by: https://github.com/angelayi
2023-08-22 21:41:51 +00:00
ydwu4
5237ed55e6 [export] allow register dataclass as pytree node (#106160)
In this pr, we allow users to register a customized flatten/unflatten/serialization/deserialization for a dataclass. We provide some default implementation for flatten/unflatten. We could implement a decorator based on it when needed.

## Motivation:
HuggingFace and many internal models return dataclass output and torch.export wants to maintain the invariant that export result (i.e. exported_program) has the same calling convention and result as the original callable.

This is not supported in export yet: we cannot recover the original dataclass from flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need to have a place to store the metadata of the dataclass so that we can re-construct it. To avoid adding hacky code in export and allow princinpled extensibility, we think extending pytree may be a good option.

## Implementation:
@zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and [jax-2371](https://github.com/google/jax/issues/2371#issuecomment-805361566), which suggests that it's not a good idea to make dataclass a default pytree node but it could be good to provide a default implementation for dataclass. Since currently, this seems to be an export-only feature, we added this extension point in export.

We also add "return_none_fields" flag to control whether none fields are returned after flattening, which is expected to be False in produce_matching of dynamo.export.

Also added some tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106160
Approved by: https://github.com/zhxchen17
2023-07-28 17:33:13 +00:00