Changes:
1. Bump `ruff` from 0.7.4 to 0.8.4
2. Change `%`-formatted strings to f-string
3. Change arguments with the `__`-prefix to positional-only arguments with the `/` separator in function signature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143753
Approved by: https://github.com/Skylion007
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
Resolves#126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves#126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
Summary:
Previously we were serializing namedtuple treespecs incorrectly:
```python
Point = namedtuple("Point", ["x", "y"])
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
We only serialize the name of the class and the fields of the namedtuple:
TreeSpec {
type='collections.namedtuple',
context={class_name='Point', class_fields={'x', 'y'}},
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec)
"""
When we load, we create a new namedtuple class containing the same fields as before,
but the is class is now a completely different class than the original one:
TreeSpec(type=namedtuple, context=torch.utils._pytree.Point, children=[*, *])
"""
spec == reconstructed_spec # False
```
So, we introduce a new API called `pytree._register_namedtuple` where users can pass in the serialized name for each namedtuple class:
```python
Point = namedtuple("Point", ["x", "y"])
pytree._register_namedtuple(Point, "Point")
p = Point(1, 2)
flat, spec = pytree.tree_flatten(p)
print(flat) # [1, 2]
print(spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
dumped_spec = pytree.treespec_dumps(spec)
print(dumped_spec)
"""
TreeSpec {
type='collections.namedtuple',
context='Point',
children=[Leaf, Leaf]
}
"""
reconstructed_spec = pytree.treespec_loads(dumped_spec)
print(reconstructed_spec) # TreeSpec(type=namedtuple, context=Point, children=[*, *])
spec == reconstructed_spec # True
```
Test Plan: `python test/test_pytree.py`
Differential Revision: D55771058
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123388
Approved by: https://github.com/zou3519
In many places in the code we use `tree_map_only((SymInt, SymBool, SymFloat), foo)` but with nested ints, it is possible to have SymInts that are non-symbolic, so we may want to do something like `tree_map_only(is_symbolic, foo)` instead.
Alternative: wrap nested int SymNodes with something other than SymInt.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119974
Approved by: https://github.com/zou3519
ghstack dependencies: #119661
Fixes#119768
- #119768
This PR adds a new function `tree_iter` that lazily iterates over the tree leaves. It is different than the `tree_leaves` function while the latter traversal the whole tree first to build a list of leaves.
```python
for leaf in tree_iter(tree):
...
```
is much more efficient than:
```python
for leaf in tree_leaves(tree):
...
```
where `tree_leaves(tree)` is `list(tree_iter(tree))`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120155
Approved by: https://github.com/vmoens
This PR introduces a key path API to pytrees, drawing direct inspiration from JAX's [key path API](https://jax.readthedocs.io/en/latest/jax-101/05.1-pytrees.html#key-paths).
I added the 3 APIs described there, and a registry of `flatten_with_keys` fns for each node type, which is a version of `flatten` that also returns `KeyEntry`s describing how to access values from the original pytree.
Current use cases for this API:
- Folks would like to do argument traversal over input pytrees to do verification and compatibility enforcement. Keypaths are useful for this—https://fburl.com/code/06p7zrvr is a handrolled pass doing basically the same thing but probably more fragilely.
- In export non-strict mode, we need to figure out a way to track sources for pytree inputs. In strict mode, dynamo handles this for us, but we'd like a decoupled component to handle this when we're not using dynamo.
I'm sure there are places it would be useful.
Some design notes:
- I only implemented the API for the Python pytree impl. optree has some differences in how their keypath APIs are designed (see https://github.com/pytorch/pytorch/issues/113378 for discussion). I have some issues with the proposed typed_path solution in that discussion and prefer JAX's API, but we can hash that out separately.
- The way folks register a `flatten_with_keys` fn is through a new kwarg to `register_pytree_node`. This follows how we do serialization fns, although the list of additional arguments is getting unwieldy.
- My impl handles pytrees with an undefined `flatten_with_keys` fn is different from JAX. I will raise an error, JAX creates a fallback keyentry.
Differential Revision: [D52547850](https://our.internmc.facebook.com/intern/diff/D52547850/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116786
Approved by: https://github.com/voznesenskym
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
It appears that `mypy` is now checking a few more previously-unchecked files; these files
are being found via import-following. Not sure exactly why they weren't being checked before.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114160
Approved by: https://github.com/eellison
ghstack dependencies: #114162
Changes:
1. Add `_private_register_pytree_node` API in both C++ and Python pytree. In C++ pytree, the API will only register pytree node for C++ pytree. In Python pytree, the API will only register pytree node for Python pytree.
2. Do not allow registering a type as pytree node twice in the Python pytree.
3. Add thread lock to the Python pytree node register API.
4. The old `_register_pytree_node` API will call the `_private_register_pytree_node` API and raise a deprecation warning.
5. Add a new `register_pytree_node` API to register node type in both C++ and Python implementations.
6. Add tests to ensure a warning will be raised when the old private function is called.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112111
Approved by: https://github.com/zou3519
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.
Before:
```text
torch
├── utils
│ ├── _pytree.py
│ ├── _cxx_pytree.py
│ ...
...
```
After:
```text
torch
├── utils
│ ├── _pytree
│ │ ├── __init__.py
│ │ └── api
│ │ ├── __init__.py
│ │ ├── cxx.py
│ │ └── python.py
│ ...
...
```
The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
Reorganized the two C++ and Python pytree submodules into a subpackage. I think this would be easier to implement the abstract `PyTreeAPI` class with two implementations. And it will be much easier for the user to switch between the two implementations.
Before:
```text
torch
├── utils
│ ├── _pytree.py
│ ├── _cxx_pytree.py
│ ...
...
```
After:
```text
torch
├── utils
│ ├── _pytree
│ │ ├── __init__.py
│ │ └── api
│ │ ├── __init__.py
│ │ ├── cxx.py
│ │ └── python.py
│ ...
...
```
The `torch.utils._pytree` module will import all APIs from `torch.utils._pytree.api.python`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112278
Approved by: https://github.com/zou3519
ghstack dependencies: #112111
Previously we added a change which required users to pass in a serialized name if they want to serialize a pytree so that the serialized name does not depend on the python environment. However this is currently breaking AOTInductor benchmark tests as AOTInductor will serialize the pytree into the .so for flattening/unflattening the inputs. However, the registration for those pytree types in the AOTInductor benchmarks are in the huggingface repo, so I'm not sure what's a good fix for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112748
Approved by: https://github.com/zhxchen17, https://github.com/malfet