Summary:
Action following https://github.com/pytorch/pytorch/issues/66232
This change does require some context: there were several suggestions regarding what to do about this group of tests: tests that are core and crucial to all of PyTorch and are too broad to be owned by one team.
1. Let's add a "module: core" and put people behind it! This idea sounds appealing unless you are one of the people backing the label. From talking to albanD among others, this idea of putting all these core tests on the shoulder of a few people or one team isn't super fair and I have not yet found anyone willing to take on this job.
2. Taking advantage of the fact that we already have a triaging oncall that takes turns triaging issues, we can leave these tests essentially unlabeled and allow the oncall to triage these tests. Since these tests are crucial to PyTorch, we'll add the "high priority" label to mark them different from other unowned tests (see https://github.com/pytorch/pytorch/issues/67552).
3. I _could_ still create an unbacked label "module: core" and attribute these tests there, but I don't like the idea of creating a facade that the tests are "triaged" to a label when no one is actually taking a look.
Now we could potentially break these tests down into smaller files so that each piece _could_ be owned by a team, but 1. I don't know if this is currently feasible and 2. This approach does not prevent that from happening in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67553
Reviewed By: albanD
Differential Revision: D32025004
Pulled By: janeyx99
fbshipit-source-id: 1fb1aa4c27e305695ab6e80ae3d02f90519939c0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62292
This PR adds pytree support for namedtuples. The challenge about namedtuple
is that each namedtuple class is actually different. This PR does the
following:
- it adds a namedtuple flatten/unflatten. The flatten function returns
a context that is the actual type of the namedtuple subclass. The
unflatten function uses that type to reconstruct the namedtuple
- Special cases all pytree logic to consider all namedtuples the same.
This is done by creating a `_get_node_type(pytree)` helper function that
returns `namedtuple` if `pytree` is any namedtuple subclass. The effect
of this is that all namedtuple subclasses will go through the namedtuple
flatten/unflatten functions
- Adds a `_namedtuple_flatten_spec` function for FX pytrees. This function
flattens the namedtuple based on the spec and is equivalent to the
`_tuple_flatten_spec`.
Test Plan
- new tests in test/test_pytree.py and test/test_fx.py
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D29947302
Pulled By: zou3519
fbshipit-source-id: 19c00665b13546642c315df0f243ad99b8e7ff7c
Summary:
```
class Foo(nn.Module):
def __init__(self):
super().__init__()
def forward(self, y, x):
for k in x:
for v in x[k]:
v += y
return x
example_dict = {'x': {'a': [fx.HOLE], 'z': [fx.HOLE, fx.HOLE]}}
new_f = fx.symbolic_trace(Foo(), concrete_args=example_dict)
print(new_f.code)
new_f(torch.randn(5), {'x': {'a': [torch.randn(5)], 'z': [torch.randn(5), torch.randn(5)]}})
fx.symbolic_trace(new_f, concrete_args=example_dict)
```
prints out
```
def forward(self, y, x):
y, tree_2, tree_3, tree_4 = pytree.tree_flatten([y, x])[0]
add = tree_2 + y
add_1 = tree_3 + y
add_2 = tree_4 + y; y = None
return {'a': [tree_2], 'z': [tree_3, tree_4]}
```
Currently, I store `in_spec` as an extra attribute on `fx.Graph`, and then include it when we do the codegen. I'm not sure if this is the right approach - it introduces a divergence between what's in `fx.Graph` and what's in the python code.
Perhaps the best API is something explicit like `fx.Graph.flatten_args`, but that does make calling things a bit ... more verbose.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55888
Reviewed By: jamesr66a
Differential Revision: D27884694
Pulled By: Chillee
fbshipit-source-id: f9e8a70c63a8df63c9f9bd0a6459255daa5a8df8
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46288
This "broadcasts" `pytree` to have the same structure as `spec`
and then flattens it.
I find it hard to describe what that does in words, so here's an example:
- Broadcasting 1 to have the same structure as [0, [0, 0]] would
return [1, [1, 1]]. Further flattening it gives us [1, 1, 1].
- Broadcasting [1, 2] to have the same structure as [0, [0, 0]] would
return [1, [2, 2]]. Further flattening it gives us [1, 2, 2].
What is this used for?
----------------------
The next PR up in the stack uses this helper function to allow vmap to
accept nested data structures. `vmap(fn, in_dims)(*inputs)` allows the
user to specify in_dims with a tree structure that is a sub-graph of
that of `inputs` (where both contain the root of the tree).
For example, one can do `vmap(fn, in_dims=0)(x, y, z)`. `in_dims` is 0
and inputs is (x, y, z). We would like to broadcast in_dims up to the
structure of inputs to get (0, 0, 0).
Another example, is `vmap(fn, in_dims=(0, 1))(x, [y, z])`. `in_dims` is
(0, 1) and inputs is (x, [y, z]). We would like to broadcast in_dims up
to the structure of inputs to get (0, [1, 1]); this value of in_dims is
used to say "let's vmap over dim 0 for x and dim 1 for y and z".
Test Plan
---------
New tests.
Test Plan: Imported from OSS
Reviewed By: heitorschueroff
Differential Revision: D24392891
Pulled By: zou3519
fbshipit-source-id: 6f494d8b6359582f1b4ab6b8dd6a956d8bfe8ed4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46287
This adds a lightweight `pytree` implementation that is similar to and
inspired by JAX pytrees, tensorflow.nest, deepmind/tree,
TorchBeast's TensorNest, etc.
A *pytree* is Python nested data structure. It is a tree in the sense
that nodes are Python collections (e.g., list, tuple, dict) and the leaves
are Python values. Furthermore, a pytree should not contain reference
cycles.
This PR:
- adds support for flattening and unflattening nested Python list/dict/tuples
Context: nested Tensor inputs for vmap
--------------------------------------
Right now, vmap is restricted to taking in flat lists of tensors. This
is because vmap needs to be able to convert every tensor in the input
that is being vmapped over into a BatchedTensor.
With a pytree library, we can simply flatten the input data structure
(returning the leaves), map all of the Tensors in the flat input to
BatchedTensors, and unflatten the flat list of BatchedTensors into a new
input. Or equivalently, with a `tree_map` function, we can map a nested
python data structure containing Tensors into one containing
BatchedTensors.
Future work
-----------
In some future PRs, we'll add nested input support for vmap. The
prerequisites for that are:
- a `broadcast_to(small, big)` that broadcasts `small` up to `big`.
This is for handling the in_dims to vmap: the in_dims structure must
be compatible with the structure of the inputs.
Test Plan
---------
- New tests in test/test_pytree.py
Test Plan: Imported from OSS
Reviewed By: heitorschueroff
Differential Revision: D24392890
Pulled By: zou3519
fbshipit-source-id: 7daf7430c5a38354e7d203a72882bd7a9b24cfb1