Commit Graph

74 Commits

Author SHA1 Message Date
Xuehai Pan
284716a691 [pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (#160843)
The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.

Changes:

1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class.
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
2025-10-29 09:16:24 +00:00
Xuehai Pan
2fa0520a64 [BE][pytree] cleanup parameterized pytree tests (#160842)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160842
Approved by: https://github.com/Skylion007
2025-09-05 20:15:29 +00:00
Narek Malkhasyan
30293b8b5e Preserve Enum types during torch.export serialization and deserialization (#154821)
Fixes #154674

Addresses an issue where `torch.export` does not correctly preserve Python `Enum` types during the save/load round-trip. Previously, Enum inputs were serialized by value only, causing their type to be lost after deserialization.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154821
Approved by: https://github.com/XuehaiPan, https://github.com/Skylion007, https://github.com/yushangdi, https://github.com/angelayi
2025-06-08 17:30:31 +00:00
angelayi
60fe0922f6 [pytree] Register normal class to register_dataclass (#147752)
Fixes https://github.com/pytorch/pytorch/pull/147532#discussion_r1964365330

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147752
Approved by: https://github.com/zou3519
2025-04-01 23:28:20 +00:00
Xuehai Pan
a10b765bf1 [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
2025-04-01 10:40:43 +00:00
PyTorch MergeBot
f9b4856989 Revert "[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)"
This reverts commit c95a6b416b.

Reverted https://github.com/pytorch/pytorch/pull/113257 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. @zou3519 can you please help land this internally? See the sigmoid tests in D71198793 for details. To validate the fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/113257#issuecomment-2725982539))
2025-03-14 23:13:34 +00:00
Xuehai Pan
c95a6b416b [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
2025-03-14 08:50:30 +00:00
PyTorch MergeBot
ebd087e4b5 Revert "[pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)"
This reverts commit f08146b67b.

Reverted https://github.com/pytorch/pytorch/pull/113257 on behalf of https://github.com/jovianjaison due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/113257#issuecomment-2711299830))
2025-03-10 17:19:21 +00:00
Xuehai Pan
098494e9cb [dynamo] allow global import from collections import deque in user code (#148676)
See https://github.com/pytorch/pytorch/pull/148669#discussion_r1983462218 for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148676
Approved by: https://github.com/jansel
2025-03-10 13:14:05 +00:00
PyTorch MergeBot
19a39a7a06 Revert "[dynamo] allow global import from collections import deque in user code (#148676)"
This reverts commit 685fb37713.

Reverted https://github.com/pytorch/pytorch/pull/148676 on behalf of https://github.com/malfet due to Looks like it broke ROCM, see f1444f006c/1(default%2C%201&mergeLF=true ([comment](https://github.com/pytorch/pytorch/pull/148676#issuecomment-2709057326))
2025-03-09 20:42:03 +00:00
Xuehai Pan
685fb37713 [dynamo] allow global import from collections import deque in user code (#148676)
See https://github.com/pytorch/pytorch/pull/148669#discussion_r1983462218 for more details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148676
Approved by: https://github.com/jansel
2025-03-09 09:35:29 +00:00
Xuehai Pan
f08146b67b [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257)
Changes in this PR:

1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.

Resolves #75982. New tests are included in this PR.

- #75982

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
2025-03-06 18:59:02 +00:00
Xuehai Pan
097b0d372a [pytree] fix previously failed dynamo tests (#148669)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148669
Approved by: https://github.com/zou3519
2025-03-06 17:59:29 +00:00
Ryan Guo
ad9a10aff0 [dynamo] Make nonstrict_trace work with some pytree.register_constant-ed instances (#148007)
As title, this enables `nonstrict_trace`-ed function to take in object
whose type has been `pytree.register_constant`-ed, as long as the object
existed outside the `torch.compile` region. This also forces Dynamo to
emit a `EQUALS_MATCH` guard on the object.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148007
Approved by: https://github.com/zou3519
ghstack dependencies: #148385
2025-03-05 21:28:26 +00:00
Xuehai Pan
9abaaad6a8 [pytree][Easy] preserve dict keys in insertion order in CXX pytree (#130140)
`optree` and JAX pytree traversal the `dict` in sorted key ordering (see [Key Ordering for Dictionaries](https://github.com/metaopt/optree#key-ordering-for-dictionaries)). While in PyTorch Python pytree, we traversal the `dict` in insertion order. See also:

- #114392

This aligns the behavior of CXX pytree with Python pytree.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130140
Approved by: https://github.com/zou3519
2025-02-12 16:41:49 +00:00
rzou
0f768c7866 Barebones flat_apply HOP (#146060)
This PR:
- adds pytree.register_constant for registering a class to be treated as
  a constant by torch.compile/torch.fx
- adds a very barebones flat_apply HOP. This should be sufficient to get
  mark_traceable working. A lot more work is necessary to get the custom
  operator case working (when make_fx sees a custom operator with PyTree
  arg types, it needs to emit a call to the flat_apply HOP).
- I expect the flat_apply HOP to change a lot, I want to ship this in
  the current state to unblock the mark_traceable and custom ops
  work.

Test Plan:
- It's kind of difficult to test the barebones flat_apply HOP "works" so
  I added a really simple test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146060
Approved by: https://github.com/StrongerXi, https://github.com/yanboliang
ghstack dependencies: #146059
2025-02-01 16:17:48 +00:00
rzou
373606928b Add torch.utils._pytree.register_dataclass (#146059)
This is an API that registers a dataclass as a pytree node.
It directly calls torch.export.register_dataclass, but we should
eventually inline that implementation here. I want to use this API for
something in compile and feel weird calling
torch.export.register_dataclass.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146059
Approved by: https://github.com/StrongerXi, https://github.com/angelayi, https://github.com/yanboliang
2025-02-01 16:17:48 +00:00
Henry Hu
f013cfee38 [TreeSpec] Support enum in defaultdict (#144235)
Summary: Followup from D66269157, add support for enum in defaultdict.

Test Plan: Added unit test

Differential Revision: D67832100

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144235
Approved by: https://github.com/henrylhtsang, https://github.com/houseroad
2025-01-07 00:10:46 +00:00
Tom Ritchford
d8c8ba2440 Fix unused Python variables in test/[e-z]* (#136964)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136964
Approved by: https://github.com/justinchuby, https://github.com/albanD
2024-12-18 23:02:30 +00:00
Xuehai Pan
7edeb1005a [dynamo][pytree][2/N] make CXX pytree traceable: tree_flatten / tree_unflatten / tree_structure (#137398)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137398
Approved by: https://github.com/jansel
2024-12-12 18:05:25 +00:00
Henry Tsang
30d907c6fb When serializing treespec context, support enum as well (#141525)
Following https://github.com/pytorch/pytorch/pull/102716, per @angelayi's suggestion.

Note that in general enum as an input is not supported.

repro:
```
class TestEnum(enum.Enum):
    A = auto()
    B = auto()

    @staticmethod
    def from_string(s):
        return TestEnum[s.upper()]

class M(torch.nn.Module):
    def forward(self, x, en):
        return x.clone()

input1 = (
    torch.rand(10, device="cuda"),
    {TestEnum.A: torch.rand(10, device="cuda")},
)
inputs = [input1]
model = M().cuda()

_ = model(*input1)

ep = torch.export.export(model, input1, strict=False)
path = torch._inductor.aot_compile(ep.module(), input1)
```

Differential Revision: D66269157
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141525
Approved by: https://github.com/angelayi
2024-12-04 03:08:50 +00:00
Xuehai Pan
4226ed1585 [BE] Format uncategorized Python files with ruff format (#132576)
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #132574
2024-08-04 17:13:31 +00:00
rzou
480ae51f85 [pytree] Only import optree if it's used (#131478)
torch.utils._pytree imports optree if it's available. Instead, we change
it to if it gets used. The motivation for this is better isolation.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131478
Approved by: https://github.com/albanD
2024-07-24 00:10:49 +00:00
Xuehai Pan
ba48cf6535 [BE][Easy][6/19] enforce style for empty lines in import segments in test/ (#129757)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129757
Approved by: https://github.com/ezyang
2024-07-17 06:42:37 +00:00
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
Xuehai Pan
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
Angela Yi
1be2126ff6 [pytree] Fix namedtuple serialization (#123388)
Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:

TreeSpec {
  type='collections.namedtuple',
  context={class_name='Point', class_fields={'x', 'y'}},
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the is class is now a completely different class than the original one:

TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""

spec == reconstructed_spec  # False
```

So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, "Point")

p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)

print(flat)  # [1, 2]
print(spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
  type='collections.namedtuple',
  context='Point',
  children=[Leaf, Leaf]
}
"""

reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)  # TreeSpec(type=namedtuple, context=Point, children=[*, *])

spec == reconstructed_spec  # True
```

Test Plan: `python test/test_pytree.py`

Differential Revision: D55771058

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
2024-04-08 20:55:19 +00:00
angelayi
cbbc309cae [pytree][reland] Require pytree serialized_type_name (#120636)
Relanding https://github.com/pytorch/pytorch/pull/119718 as the diff which prevents breakages of torchrec [D53857843](https://www.internalfb.com/diff/D53857843) has landed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120636
Approved by: https://github.com/avikchaudhuri
2024-02-27 06:53:33 +00:00
Xuehai Pan
be0ee93467 [pytree] support X | Y union type in tree_map_only (#120389)
Follow-up PR for #119974 with some small tweaks.

1. Support `X | Y` union type for Python 3.10+
2. Enable predicate function in `tree_map_only` in CXX pytree.
3. Remove unnecessary function definition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120389
Approved by: https://github.com/zou3519
2024-02-22 18:17:13 +00:00
soulitzer
2e77629b9f [pytrees] Allow tree_map_only to support predicate function as filter (#119974)
In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead.

Alternative: wrap nested int SymNodes with something other than SymInt.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974
Approved by: https://github.com/zou3519
ghstack dependencies: #119661
2024-02-21 21:10:02 +00:00
Wilson Hong
3f4dd9bfa4 Back out "[pytree] Require serialized_type_name" (#120041)
Summary:
D53785493 breaks apf.rec.ir.tests.ir_export_deserialize_test.IRExportDeserializeTest: test_export_deserialize_ebc failed:

https://www.internalfb.com/sandcastle/workflow/3436246515685789584

Test Plan: buck2 test mode/opt apf/rec/ir/tests:ir_export_deserialize_test

Differential Revision: D53834881

Co-authored-by: Wilson Hong <wilsonhong@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120041
Approved by: https://github.com/ydwu4
2024-02-16 10:02:25 +00:00
angelayi
b4c7afe101 [pytree] Require serialized_type_name (#119718)
Differential Revision: [D53785493](https://our.internmc.facebook.com/intern/diff/D53785493)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119718
Approved by: https://github.com/suo
2024-02-15 20:32:44 +00:00
Aaron Gokaslan
1562dae62c [BE]: Apply RUF025 dict.fromkeys preview rule (#118637)
Simplifies and optimizes dict construction using the `fromkeys` classmethod ctor. This also makes it really obvious when all the keys will have the same static value, which could be a bug if unintentional. It is also significantly faster than using a dict comprehension. The rule is in preview, but I am adding a forward fix for when it becomes stable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118637
Approved by: https://github.com/albanD
2024-01-30 20:46:54 +00:00
suo
e732adf0a7 [pytree] add access api (#117771)
This PR introduces an API to use KeyPaths to actually access values on pytrees.

Differential Revision: [D52881260](https://our.internmc.facebook.com/intern/diff/D52881260/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117771
Approved by: https://github.com/zou3519, https://github.com/XuehaiPan
2024-01-20 04:03:26 +00:00
suo
9448065061 [pytree] add key path api (#116786)
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).

I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.

Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.

I'm sure there are places it would be useful.

Some design notes:
- I only implemented the API for  the Python pytree impl. optree has some differences in how their keypath APIs are designed (see https://github.com/pytorch/pytorch/issues/113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry.

Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116786
Approved by: https://github.com/voznesenskym
2024-01-17 07:24:35 +00:00
Xuehai Pan
ab1ac43752 [pytree] extend pytree operations with is_leaf prediction function (#116419)
Add an extra `is_leaf` prediction function to pytree operations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116419
Approved by: https://github.com/zou3519
2024-01-09 19:50:08 +00:00
suo
902807a86d enable pytree tests in fbcode (#116787)
these were not runnable before

Differential Revision: [D52547846](https://our.internmc.facebook.com/intern/diff/D52547846/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116787
Approved by: https://github.com/zou3519
2024-01-09 19:12:43 +00:00
Xuehai Pan
36c6c0c7dc [pytree] expand tree_map to accept multi-inputs (#115642)
Fixes #115419
Fixes #91323
Closes #115549

- #115419
- #91323

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115642
Approved by: https://github.com/vmoens, https://github.com/zou3519
2023-12-14 06:16:42 +00:00
Xuehai Pan
ec124b90b8 [pytree] hardcode values for none_is_leaf and namespace in C++ pytree (#114858)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114858
Approved by: https://github.com/zou3519
2023-12-01 15:01:33 +00:00
Xuehai Pan
d6c0d1b58b [pytree] support collections.deque type for Python pytree (#113256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113256
Approved by: https://github.com/zou3519
ghstack dependencies: #112485, #113255
2023-12-01 05:12:09 +00:00
Xuehai Pan
2ab2e8e1c0 [pytree] support collections.defaultdict type for Python pytree (#113255)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113255
Approved by: https://github.com/zou3519
ghstack dependencies: #112485
2023-11-30 20:46:25 +00:00
Xuehai Pan
2a3d8e50fb [pytree] test aligned API signature for C++ and Python pytree (#112485)
Add tests to ensure the C++ and Python pytree provide the same APIs with identical signatures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112485
Approved by: https://github.com/zou3519
2023-11-30 17:50:06 +00:00
Xuehai Pan
89a1fe6966 [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-28 11:41:38 +00:00
PyTorch MergeBot
01366efcc9 Revert "[pytree] register pytree node type in both C++ pytree and Python pytree (#112111)"
This reverts commit 4e4a6ad6ec.

Reverted https://github.com/pytorch/pytorch/pull/112111 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/112111#issuecomment-1824099658))
2023-11-23 09:59:32 +00:00
Xuehai Pan
4e4a6ad6ec [pytree] register pytree node type in both C++ pytree and Python pytree (#112111)
Changes:

1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
2023-11-21 19:53:13 +00:00
PyTorch MergeBot
23e0923c74 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit eeeb40b327.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to Reverting this pr as the one under it in the stack is causing regressions in torchrec ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1806044435))
2023-11-10 16:30:36 +00:00
Xuehai Pan
eeeb40b327 [pytree] reorganize submodule structure for C++ and Python pytree (#112278)
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.

Before:

```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```

After:

```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```

The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
2023-11-10 05:41:32 +00:00
PyTorch MergeBot
bf452dcde6 Revert "[pytree] reorganize submodule structure for C++ and Python pytree (#112278)"
This reverts commit fa895da968.

Reverted https://github.com/pytorch/pytorch/pull/112278 on behalf of https://github.com/PaliC due to in the bottom diff in the stack changing _register_pytree_node's signature is bc breaking, please revert the signature and reland ([comment](https://github.com/pytorch/pytorch/pull/112278#issuecomment-1804870560))
2023-11-10 00:12:52 +00:00
Xuehai Pan
fa895da968 [pytree] reorganize submodule structure for C++ and Python pytree (#112278)
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.

Before:

```text
torch
├── utils
│   ├── _pytree.py
│   ├── _cxx_pytree.py
│   ...
...
```

After:

```text
torch
├── utils
│   ├── _pytree
│   │   ├── __init__.py
│   │   └── api
│   │       ├── __init__.py
│   │       ├── cxx.py
│   │       └── python.py
│   ...
...
```

The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
2023-11-08 06:05:39 +00:00