Commit Graph

77 Commits

Author SHA1 Message Date
Aaron Gokaslan
12e95aa4ee [BE]: Apply PERF401 autofixes from ruff (#140980)
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
2024-11-20 17:52:07 +00:00
Xuehai Pan
9bbe4a67ad [dynamo] support maxlen for collections.deque (#138194)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138194
Approved by: https://github.com/jansel, https://github.com/malfet
2024-10-30 10:08:02 +00:00
rzou
480ae51f85 [pytree] Only import optree if it's used (#131478)
torch.utils._pytree imports optree if it's available. Instead, we change
it to if it gets used. The motivation for this is better isolation.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131478
Approved by: https://github.com/albanD
2024-07-24 00:10:49 +00:00
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
Xuehai Pan
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
Xuehai Pan
3b0f6cce5c [pytree] freeze attributes of TreeSpec (#124011)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124011
Approved by: https://github.com/zou3519
2024-05-22 05:57:00 +00:00
Xuehai Pan
2e48f7b044 [pytree] add tree_iter function (#123913)
- Add a new `tree_iter` function.
- Bump `optree` version to `0.11.0` for C++ version of `tree_iter`.

This PR is split from #120300.

- #120300

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123913
Approved by: https://github.com/zou3519
2024-04-16 06:02:08 +00:00
Xuehai Pan
9bb54c7f3c [pytree] enable functools.wraps in Python pytree with dynamo (#124012)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124012
Approved by: https://github.com/Skylion007
2024-04-14 09:25:05 +00:00
Angela Yi
1be2126ff6 [pytree] Fix namedtuple serialization (#123388)
Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:

TreeSpec {
  type='collections.namedtuple',
  context={class_name='Point', class_fields={'x', 'y'}},
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the is class is now a completely different class than the original one:

TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""

spec == reconstructed_spec  # False
```

So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, "Point")

p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
  type='collections.namedtuple',
  context='Point',
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

spec == reconstructed_spec  # True
```

Test Plan: `python test/test_pytree.py`

Differential Revision: D55771058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
2024-04-08 20:55:19 +00:00
angelayi
cbbc309cae [pytree][reland] Require pytree serialized_type_name (#120636)
Relanding https://github.com/pytorch/pytorch/pull/119718 as the diff which prevents breakages of torchrec [D53857843](https://www.internalfb.com/diff/D53857843) has landed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120636
Approved by: https://github.com/avikchaudhuri
2024-02-27 06:53:33 +00:00
Xuehai Pan
be0ee93467 [pytree] support X | Y union type in tree_map_only (#120389)
Follow-up PR for #119974 with some small tweaks.

1. Support `X | Y` union type for Python 3.10+
2. Enable predicate function in `tree_map_only` in CXX pytree.
3. Remove unnecessary function definition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120389
Approved by: https://github.com/zou3519
2024-02-22 18:17:13 +00:00
soulitzer
2e77629b9f [pytrees] Allow tree_map_only to support predicate function as filter (#119974)
In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead.

Alternative: wrap nested int SymNodes with something other than SymInt.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974
Approved by: https://github.com/zou3519
ghstack dependencies: #119661
2024-02-21 21:10:02 +00:00
PyTorch MergeBot
a1fc29cd78 Revert "[pytree] add function tree_iter (#120155)"
This reverts commit 372d078f36.

Reverted https://github.com/pytorch/pytorch/pull/120155 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/120155#issuecomment-1955479765))
2024-02-21 00:21:28 +00:00
Xuehai Pan
372d078f36 [pytree] add function tree_iter (#120155)
Fixes #119768

- #119768

This PR adds a new function `tree_iter` that lazily iterates over the tree leaves. It is different than the `tree_leaves` function while the latter traversal the whole tree first to build a list of leaves.

```python
for leaf in tree_iter(tree):
    ...
```

is much more efficient than:

```python
for leaf in tree_leaves(tree):
    ...
```

where `tree_leaves(tree)` is `list(tree_iter(tree))`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120155
Approved by: https://github.com/vmoens
2024-02-18 09:16:50 +00:00
Wilson Hong
3f4dd9bfa4 Back out "[pytree] Require serialized_type_name" (#120041)
Summary:
D53785493 breaks apf.rec.ir.tests.ir_export_deserialize_test.IRExportDeserializeTest: test_export_deserialize_ebc failed:

https://www.internalfb.com/sandcastle/workflow/3436246515685789584

Test Plan: buck2 test mode/opt apf/rec/ir/tests:ir_export_deserialize_test

Differential Revision: D53834881

Co-authored-by: Wilson Hong <wilsonhong@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120041
Approved by: https://github.com/ydwu4
2024-02-16 10:02:25 +00:00
angelayi
b4c7afe101 [pytree] Require serialized_type_name (#119718)
Differential Revision: [D53785493](https://our.internmc.facebook.com/intern/diff/D53785493)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119718
Approved by: https://github.com/suo
2024-02-15 20:32:44 +00:00
Xuehai Pan
42062e2622 [pytree][BE] update treespec is_leaf() access (#116371)
Change `isinstance(treespec, LeafSpec) -> treespec.is_leaf()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116371
Approved by: https://github.com/zou3519
2024-01-27 11:44:57 +00:00
Guilherme Leobas
80cf0ce153 Enhance torch.vmap support from inside torch.compile (#116050)
This work rewrites vmap support in torch.compile by inlining most of
the frames into the existing FX graph. It also unlocks to PyTorch to
support features that were previously missing, such as keyword args.

Fixes: https://github.com/pytorch/pytorch/issues/114306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116050
Approved by: https://github.com/zou3519
2024-01-22 17:53:45 +00:00
suo
e732adf0a7 [pytree] add access api (#117771)
This PR introduces an API to use KeyPaths to actually access values on pytrees.

Differential Revision: [D52881260](https://our.internmc.facebook.com/intern/diff/D52881260/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117771
Approved by: https://github.com/zou3519, https://github.com/XuehaiPan
2024-01-20 04:03:26 +00:00
Xuehai Pan
c0940d2e93 [pytree] reuse flatten_fn in flatten_with_keys_fn to ensure consistency (#117656)
Reuse `flatten_fn` in `flatten_with_keys_fn` to ensure `flatten_fn` and `flatten_with_keys_fn` get the same `leaves` and `context`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117656
Approved by: https://github.com/suo
2024-01-17 20:38:49 +00:00
suo
9448065061 [pytree] add key path api (#116786)
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).

I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.

Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are places it would be useful.

Some design notes:
- I only implemented the API for  the Python pytree impl. optree has some differences in how their keypath APIs are designed (see https://github.com/pytorch/pytorch/issues/113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry.

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116786
Approved by: https://github.com/voznesenskym
2024-01-17 07:24:35 +00:00
Xuehai Pan
ab1ac43752 [pytree] extend pytree operations with is_leaf prediction function (#116419)
Add an extra `is_leaf` prediction function to pytree operations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116419
Approved by: https://github.com/zou3519
2024-01-09 19:50:08 +00:00
Xuehai Pan
199e07f108 [pytree][BE] update treespec num_children access (#116370)
Change `len(treespec.children_spes) -> treespec.num_children`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116370
Approved by: https://github.com/Skylion007
2023-12-24 20:54:32 +00:00
Xuehai Pan
36c6c0c7dc [pytree] expand tree_map to accept multi-inputs (#115642)
Fixes #115419
Fixes #91323
Closes #115549

- #115419
- #91323

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115642
Approved by: https://github.com/vmoens, https://github.com/zou3519
2023-12-14 06:16:42 +00:00
Xuehai Pan
d6c0d1b58b [pytree] support collections.deque type for Python pytree (#113256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113256
Approved by: https://github.com/zou3519
ghstack dependencies: #112485, #113255
2023-12-01 05:12:09 +00:00
Xuehai Pan
2ab2e8e1c0 [pytree] support collections.defaultdict type for Python pytree (#113255)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113255
Approved by: https://github.com/zou3519
ghstack dependencies: #112485
2023-11-30 20:46:25 +00:00
Xuehai Pan
2a3d8e50fb [pytree] test aligned API signature for C++ and Python pytree (#112485)
Add tests to ensure the C++ and Python pytree provide the same APIs with identical signatures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112485
Approved by: https://github.com/zou3519
2023-11-30 17:50:06 +00:00
Xuehai Pan
89a1fe6966 [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-28 11:41:38 +00:00
Jez Ng
5cfa0647a7 Update mypy to 1.7.0 (#114160)
It appears that `mypy` is now checking a few more previously-unchecked files; these files
are being found via import-following. Not sure exactly why they weren't being checked before.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114160
Approved by: https://github.com/eellison
ghstack dependencies: #114162
2023-11-28 06:45:55 +00:00
PyTorch MergeBot
01366efcc9 Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit 4e4a6ad6ec.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1824099658))
2023-11-23 09:59:32 +00:00
Xuehai Pan
4e4a6ad6ec [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-21 19:53:13 +00:00
PyTorch MergeBot
2a271a3efa Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit a0d00349ed.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/PaliC due to _private_register_pytree_node now checks for duplicate registering, unfortunately, this breaks composability with torchrec internally :(  ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1806130993))
2023-11-10 17:24:40 +00:00
PyTorch MergeBot
23e0923c74 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit eeeb40b327.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to Reverting this pr as the one under it in the stack is causing regressions in torchrec ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1806044435))
2023-11-10 16:30:36 +00:00
Xuehai Pan
eeeb40b327 [pytree] reorganize submodule structure for C++ and Python pytree (#112278)
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.

Before:

```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```

After:

```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```

The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
2023-11-10 05:41:32 +00:00
Xuehai Pan
a0d00349ed [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-10 02:41:30 +00:00
Xuehai Pan
5e2adc8650 [pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations.

- #112485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482
Approved by: https://github.com/zou3519
2023-11-10 02:37:48 +00:00
PyTorch MergeBot
66150b29e3 Revert "[pytree] align function signature between C++ and Python pytree (#112482)"
This reverts commit 4893a2814f.

Reverted https://github.com/pytorch/pytorch/pull/112482 on behalf of https://github.com/PaliC due to changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112482#issuecomment-1804909926))
2023-11-10 00:59:23 +00:00
PyTorch MergeBot
9a90989121 Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit 95f52611c7.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/PaliC due to in the bottom diff in the stack changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1804892924))
2023-11-10 00:38:28 +00:00
PyTorch MergeBot
bf452dcde6 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit fa895da968.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to in the bottom diff in the stack changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1804870560))
2023-11-10 00:12:52 +00:00
Xuehai Pan
fa895da968 [pytree] reorganize submodule structure for C++ and Python pytree (#112278)
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.

Before:

```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```

After:

```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```

The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
2023-11-08 06:05:39 +00:00
Xuehai Pan
95f52611c7 [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-08 05:02:03 +00:00
Xuehai Pan
4893a2814f [pytree] align function signature between C++ and Python pytree (#112482)
Change the argument name in C++ and Python pytree APIs. Also add a test to ensure the function signatures are the same in the two implementations.

- #112485

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112482
Approved by: https://github.com/zou3519
2023-11-07 01:26:41 +00:00
angelayi
3904b81420 [pytree] Add back a default serialized name (#112748)
Previously we added a change which required users to pass in a serialized name if they want to serialize a pytree so that the serialized name does not depend on the python environment. However this is currently breaking AOTInductor benchmark tests as AOTInductor will serialize the pytree into the .so for flattening/unflattening the inputs. However, the registration for those pytree types in the AOTInductor benchmarks are in the huggingface repo, so I'm not sure what's a good fix for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112748
Approved by: https://github.com/zhxchen17, https://github.com/malfet
2023-11-02 22:34:42 +00:00
PyTorch MergeBot
ca33dd780e Revert "[pytree] Add back a default serialized name (#112748)"
This reverts commit ca72d23613.

Reverted https://github.com/pytorch/pytorch/pull/112748 on behalf of https://github.com/angelayi due to sorry, was trying to fix CI and broke CI ([comment](https://github.com/pytorch/pytorch/pull/112748#issuecomment-1791098635))
2023-11-02 16:47:59 +00:00
angelayi
ca72d23613 [pytree] Add back a default serialized name (#112748)
Previously we added a change which required users to pass in a serialized name if they want to serialize a pytree so that the serialized name does not depend on the python environment. However this is currently breaking AOTInductor benchmark tests as AOTInductor will serialize the pytree into the .so for flattening/unflattening the inputs. However, the registration for those pytree types in the AOTInductor benchmarks are in the huggingface repo, so I'm not sure what's a good fix for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112748
Approved by: https://github.com/zhxchen17, https://github.com/malfet
2023-11-02 16:18:03 +00:00
angelayi
ff35e1e45b [pytree] Add custom treespec fqn field (#112428)
Custom classes that are serialized with pytree are serialized by default with `f”{class.__module__}.{class.__name__}”`. This is a dependency from our serialized program directly into the outer Python environment. If a user moves the class to a different directory, the serialized program will be unable to be loaded. So, we will require users to pass in an FQN if they want to serialize their custom treespec type.

Differential Revision: [D50886366](https://our.internmc.facebook.com/intern/diff/D50886366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112428
Approved by: https://github.com/suo
2023-11-02 00:26:41 +00:00
Peter Bell
046c0c66fd [pytree] Add arg_tree_leaves to optimize flattening function arguments (#112393)
We commonly do some variation of `tree_leaves((args, kwargs))`. This adds a new
function `arg_tree_leaves(*args, **kwargs)` which takes advantage of the known
structure of `args` and `kwargs` to skip their `flatten_fn`.

I see ~1 us improvement per call for args + kwargs, or a 0.5 us improvement
when passing just one of `args` or `kwargs`. For shallow structures, this can be
proportionally quite significant. For example, the empty_strided call I've been
using as a benchmark:
```
args = ((100, 100), (100, 1))
kwargs = dict(device="cuda")
```
Sees a 30% speedup from this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112393
Approved by: https://github.com/lezcano
ghstack dependencies: #112391, #112392
2023-10-31 15:57:00 +00:00
Peter Bell
31c0ef934b [pytree] Remove LeafSpec construction cost in tree_flatten (#112392)
On my machine, `pytree.LeafSpec()` takes ~600ns but since every leaf spec is the
same, we can just use a global constant.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112392
Approved by: https://github.com/lezcano
ghstack dependencies: #112391
2023-10-30 21:45:45 +00:00
Peter Bell
0f2b7a99e3 [pytree] Avoid constructing intermediate lists in tree_{flatten,leaves} (#112391)
Instead of concatenating lists of child nodes, this appends the leaf nodes
directly onto the list of leaves to be returned which gives a small perf
improvement.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112391
Approved by: https://github.com/zou3519
2023-10-30 21:45:45 +00:00