The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.
Changes:
1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class.
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class.
Changes:
1. Add function `treespec_leaf()` to replace `LeafSpec()`.
2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class.
3. Change `len(spec.children_specs)` to `spec.num_children`.
4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`.
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160843
Approved by: https://github.com/mlazos
Changes in this PR:
1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.
Resolves#75982. New tests are included in this PR.
- #75982
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
Changes in this PR:
1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.
Resolves#75982. New tests are included in this PR.
- #75982
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
Changes in this PR:
1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence.
2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types.
3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class.
Resolves#75982. New tests are included in this PR.
- #75982
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257
Approved by: https://github.com/zou3519
As title, this enables `nonstrict_trace`-ed function to take in object
whose type has been `pytree.register_constant`-ed, as long as the object
existed outside the `torch.compile` region. This also forces Dynamo to
emit a `EQUALS_MATCH` guard on the object.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148007
Approved by: https://github.com/zou3519
ghstack dependencies: #148385
This PR:
- adds pytree.register_constant for registering a class to be treated as
a constant by torch.compile/torch.fx
- adds a very barebones flat_apply HOP. This should be sufficient to get
mark_traceable working. A lot more work is necessary to get the custom
operator case working (when make_fx sees a custom operator with PyTree
arg types, it needs to emit a call to the flat_apply HOP).
- I expect the flat_apply HOP to change a lot, I want to ship this in
the current state to unblock the mark_traceable and custom ops
work.
Test Plan:
- It's kind of difficult to test the barebones flat_apply HOP "works" so
I added a really simple test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146060
Approved by: https://github.com/StrongerXi, https://github.com/yanboliang
ghstack dependencies: #146059
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
Resolves#126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves#126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:
TreeSpec {
type='collections.namedtuple',
context={class_name='Point', class_fields={'x', 'y'}},
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the is class is now a completely different class than the original one:
TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""
spec == reconstructed_spec # False
```
So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, "Point")
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
type='collections.namedtuple',
context='Point',
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
spec == reconstructed_spec # True
```
Test Plan: `python test/test_pytree.py`
Differential Revision: D55771058
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead.
Alternative: wrap nested int SymNodes with something other than SymInt.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974
Approved by: https://github.com/zou3519
ghstack dependencies: #119661
Simplifies and optimizes dict construction using the `fromkeys` classmethod ctor. This also makes it really obvious when all the keys will have the same static value, which could be a bug if unintentional. It is also significantly faster than using a dict comprehension. The rule is in preview, but I am adding a forward fix for when it becomes stable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118637
Approved by: https://github.com/albanD
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).
I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.
Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.
I'm sure there are places it would be useful.
Some design notes:
- I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see https://github.com/pytorch/pytorch/issues/113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry.
Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116786
Approved by: https://github.com/voznesenskym
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.
Before:
```text
torch
├── utils
│ ├── _pytree.py
│ ├── _cxx_pytree.py
│ ...
...
```
After:
```text
torch
├── utils
│ ├── _pytree
│ │ ├── __init__.py
│ │ └── api
│ │ ├── __init__.py
│ │ ├── cxx.py
│ │ └── python.py
│ ...
...
```
The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111