This should be enough to get @voznesenskym's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything no-op out as expected. Main changes are:
(1) Graph break on `aten::set_.source_Tensor_storage_offset` (we could support it, but it isn't needed and graph breaking seems safer).
(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and wrong (it would just `clone()` and call `set_()`, which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.
(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations and metadata mutations due to `set_()` calls. This mainly matters at runtime: the former requires calling `as_strided_()` to fix up metadata, while the latter requires calling `set_()` (see the sketch after this list).
(4) Made AOTAutograd's detection of metadata mutations / `set_()` mutations smarter, so it detects no-ops (when the storage and metadata are all unchanged).
I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with (existing) `has_data_mutation()` and (new) `has_metadata_mutation()`, which can more accurately distinguish between data mutations, `set_()` calls, and metadata mutations.
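To make the distinction in (3) concrete, here is a minimal, hypothetical sketch of the two runtime fix-ups (illustrative Python only, not the actual AOTAutograd epilogue code):
```
import torch

# "Regular" metadata mutation (e.g. the traced function called x.t_()):
# replay the new sizes/strides onto the same storage with as_strided_().
x = torch.randn(2, 3)
x_meta = x.t()  # stand-in for "what x's metadata should look like after the graph"
x.as_strided_(x_meta.size(), x_meta.stride())

# set_()-style mutation: the input must end up pointing at an entirely new
# storage, which as_strided_() cannot do; we have to call set_() instead.
y = torch.randn(2, 3)
y_new = torch.ones(2)  # stand-in for the graph output holding the new storage
y.set_(y_new)
assert y.data_ptr() == y_new.data_ptr()
```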
**This PR is still silently incorrect in one case though**, which I'd like to discuss more. In particular, consider this example:
```
def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```
If you have an input that experiences both a data mutation **and** an `x_old.set_(x_new)` call, there are two cases:
(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if `x_new` is a graph intermediate then we will functionalize the mutation, and if `x_new` is a different graph input, then we will perform the usual `copy_()` on that other graph input.
(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```
def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations: a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output

def runtime_wrapper(x):
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```
There are two things that make this difficult to do though:
(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.
(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.
It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554
Approved by: https://github.com/ezyang
ghstack dependencies: #113926
We spend somewhere on the order of 1% of our time in `sympy.Expr.free_symbols`, as it is called millions of times.
Most of the time we actually just want to know "is this a constant?", however `e.is_constant()` is
horribly slow. It turns out, though, that there is another property, `is_number`, that does what we want.
> property is_number:
>
> Returns True if self has no free symbols and no undefined functions (AppliedUndef, to be precise). It will be faster
> than if not self.free_symbols, however, since is_number will fail as soon as it hits a free symbol or undefined
> function.
On top of that, using `is_number` also avoids the overhead of building an unnecessary set object.
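For illustration, a small sketch of the two idioms (not the actual call sites):
```
import sympy

x = sympy.Symbol("x")
const_expr = sympy.Integer(2) * 3 + 1
sym_expr = x + 1

# Old idiom: materializes a set of free symbols just to check whether it is empty.
assert len(const_expr.free_symbols) == 0
assert len(sym_expr.free_symbols) == 1

# New idiom: short-circuits as soon as it hits a free symbol or undefined
# function, and never builds the set.
assert const_expr.is_number
assert not sym_expr.is_number
```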
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112688
Approved by: https://github.com/lezcano
Use conditional imports: when running under dynamo, import the original NumPy, not torch._numpy. The original NumPy is what we want to trace, not our implementation.
With this, the test suite passes with and without `PYTORCH_TEST_WITH_DYNAMO=1` (modulo a couple of test modules which are not meant to be compiled, e.g. `test_nep50_examples`). There are two new decorators, `x{fail,pass}ifTorchDynamo`; the `xpass` in most cases indicates a graph break and a fallback to eager for things we do not implement.
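A sketch of the conditional-import pattern (assuming the `TEST_WITH_TORCHDYNAMO` flag from `torch.testing._internal.common_utils`; the real diff may gate on something slightly different):
```
from torch.testing._internal.common_utils import TEST_WITH_TORCHDYNAMO

if TEST_WITH_TORCHDYNAMO:
    # Under dynamo we want to trace the original NumPy (which dynamo then maps
    # onto torch._numpy itself), not our reimplementation.
    import numpy as np
else:
    import torch._numpy as np
```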
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110401
Approved by: https://github.com/lezcano
AOTAutograd's handling of `resize_()` isn't fully robust (and on top of that, functionalization can potentially give up and raise an error if the tensor you're resizing has outstanding views).
Given that, and given that `resize_()` is rare, I updated dynamo to graph break on `resize_()` instead.
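A toy repro of the new behavior (my own sketch, not from the PR): the `resize_()` call now triggers a graph break, so it runs in eager rather than being traced.
```
import torch

def f(x):
    y = x.sin()
    x.resize_(0)  # dynamo graph breaks here; resize_() itself runs in eager
    return y + 1

compiled = torch.compile(f, backend="eager")
out = compiled(torch.randn(4))
```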
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111553
Approved by: https://github.com/ezyang
Did some easy fixes from enabling TRY200 (raising inside an `except` block without `from`). Most of these seem like oversights rather than intentional choices. The proper way to silence intentional cases is with `from None`, to note that you thought about whether the new exception should carry the cause and decided against it.
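For reference, a hypothetical sketch of the two idioms TRY200 distinguishes (not code from the diff):
```
def parse_port(value: str) -> int:
    try:
        return int(value)
    except ValueError as err:
        # Preferred: chain the original exception so the cause shows up in the traceback.
        raise RuntimeError(f"invalid port: {value!r}") from err


def parse_port_opaque(value: str) -> int:
    try:
        return int(value)
    except ValueError:
        # Intentional suppression: `from None` documents that the cause was
        # considered and deliberately dropped, and it silences TRY200.
        raise RuntimeError(f"invalid port: {value!r}") from None
```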
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111496
Approved by: https://github.com/malfet
Fixes part 1 of https://github.com/pytorch/pytorch/issues/111370#issuecomment-1764730773
While at it, add a test for the numpy ndarray `.size` attribute. This started as an attempt to remove the delegation of what looks like a `.size()` method (which does not exist in numpy) on the same line this patch adds `tolist` to.
But that delegation is apparently needed for something else, and existing tests start failing without it. Thus, declare it _ain't broken, don't fix_, and only keep the test. The test can be removed if wanted, though.
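A minimal sketch of the behavior the new test covers (hypothetical code, not the literal test; assumes NumPy inputs are traced via the torch._numpy interop):
```
import numpy as np
import torch

@torch.compile(backend="eager")
def f(x):
    # `.size` is an ndarray attribute (total element count), not a method.
    return np.zeros(x.size)

out = f(np.arange(6).reshape(2, 3))
assert out.shape == (6,)
```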
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111382
Approved by: https://github.com/lezcano
Fixes #109604
Resubmit gh-109715 + several skips and small fixes to make tests pass.
The main fix here is by @ysiraichi: previously, dynamo did not resume tracing numpy ndarrays after a graph break.
While at it, fix several small issues that Yukio's fix uncovered:
- graph break gracefully on numpy dtypes which do not map to torch.dtypes (uint16, etc.)
- recognize array scalars in dynamo and treat them as 0D ndarrays
- make sure that iterating over `torch._numpy.ndarray` generates arrays, not bare tensors (see the sketch below)
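A sketch of the iteration behavior from the last bullet (a hypothetical example, not the test added in this PR):
```
import numpy as np
import torch

@torch.compile(backend="eager")
def row_sums(x):
    # Iterating over a 2D array should yield 1D ndarrays (rows), not bare
    # torch.Tensors, so NumPy methods keep working on each element.
    return [row.sum() for row in x]

out = row_sums(np.arange(6).reshape(2, 3))
```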
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110512
Approved by: https://github.com/lezcano
The main thrust of the initial effort here was to capture `register_hook` calls on tensors in compiled regions. The first part of this was done in https://github.com/pytorch/pytorch/pull/108903, wherein we added support for `register_hook` on input tensors.
The distinction between input and intermediary is due to implementation differences.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
**The plan:**
For tensors w/ a source: (The PR above)
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2-modified bytecode with the original eager code, we call `register_hook`. This registration of hooks in residuals is sound because (a) it happens right after a PT2 frame region ends, and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the user's invoking frame. This means we can soundly know it will be around to invoke `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.
For tensors w/o a source: (This PR)
Ostensibly, the most correct and complete solution would be to smuggle hooks into a runtime wrapper in aot_autograd, where all the items the hooks close over are lifted to inputs as necessary and passed alongside the user provided function. This is necessary so that we can properly trace out and capture all the mutations within the user defined hook at backwards time.
This is too complicated, so we limited the scope of this initial PR to a simple subset of hooks:
- Hooks must have a source (be known to us already, not a lambda or intermediary defined function)
- We must be tracing under compiled autograd
**The flow**:
We use the HOP added in https://github.com/pytorch/pytorch/pull/109690/files, referred to as the HOP below.
1) We intercept register_hook calls and wrap the user defined fn in the HOP
2) We write a `_register_hook_trampoline` to the graph that is a local no-arg function that is invoked as a call_function in the dynamo graph
3) aot_autograd inlines through it during its trace, and sees the HOP
4) the HOP preserves itself in the graph - it does not get traced into
5) During backwards, compiled_autograd installs the HOP under a hook call
6) When compiled_autograd enters compilation over its generated graph, dynamo traces the contents of the hook
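A usage sketch of what this enables (my own hypothetical example; it assumes the `torch._dynamo.compiled_autograd.enable` context manager and a module-level hook, per the restrictions above):
```
import torch
import torch._dynamo.compiled_autograd as compiled_autograd

def double_grad(grad):          # hook with a source (module-level function)
    return grad * 2

@torch.compile(backend="eager")
def f(x):
    y = x.sin()                 # y is an intermediary (it has no source)
    y.register_hook(double_grad)
    return y.cos()

def compiler_fn(gm):
    # compiled_autograd calls this on the backward graph it captures;
    # dynamo then traces the contents of the hook.
    return torch.compile(gm, backend="eager")

x = torch.randn(3, requires_grad=True)
with compiled_autograd.enable(compiler_fn):
    f(x).sum().backward()
```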
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109537
Approved by: https://github.com/ezyang
The strategy in this PR is pretty straightforward.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
The plan:
**For tensors w/ a source:**
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2-modified bytecode with the original eager code, we call `register_hook`. This registration of hooks in residuals is sound because (a) it happens right after a PT2 frame region ends, and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the user's invoking frame. This means we can soundly know it will be around to invoke `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.
**For tensors w/o a source:**
Graph break - we will support this in a subsequent PR
**Handles:**
An interesting new component here is the creation of a `STORE_FAST` -> `LOAD_FAST` pair associated with the handle, the return result of `register_hook`. If the user code stored the result of `register_hook` in a handle, we need to honor that. We do so by intercepting `STORE_FAST` and recording the name of the local variable as directed by user code. We then honor that same name in the reconstructed bytecode. If the user did not store the handle, we merely pop the produced value to preserve the stack.
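A small usage sketch of the input-hook path (hypothetical example, not from the PR): the hook is a module-level function (so it has a source), and the handle is stored under the user's local name exactly as described above.
```
import torch

def scale_grad(grad):
    return grad * 2

@torch.compile(backend="eager")
def f(x):
    # x is a graph input (it has a source); the handle `h` is kept under the
    # same local name via the STORE_FAST interception, and the actual
    # register_hook call happens in the residual bytecode after the region.
    h = x.register_hook(scale_grad)
    return (x * x).sum()

x = torch.randn(3, requires_grad=True)
f(x).backward()
assert torch.allclose(x.grad, 4 * x)  # 2x from the backward, doubled by the hook
```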
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108903
Approved by: https://github.com/ezyang
ghstack dependencies: #108846, #109092
RFC: https://github.com/pytorch/rfcs/pull/54
First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/
We have already been using this in core for the last few months as an external dependency. This PR pulls all of it into core.
In the next commits, I do a number of things, in this order:
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number of dynamo tests that were passing before (I think they were not testing anything) and are not passing now.
Missing from this PR (but not blocking):
- Have a flag that deactivates tracing NumPy functions and simply graph breaks. There used to be one, but it stopped working after the merge and I removed it. @lezcano to investigate.
- https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge.
All the tests in `tests/torch_np` take about 75s to run.
This was work by @ev-br, @rgommers, @honno, and me. I did not create this PR via ghstack (which would have been convenient) as this is a collaboration, and ghstack doesn't allow for shared contributions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211
Approved by: https://github.com/ezyang
This PR adds initial dynamo support for DTensor. In particular:
- It allows a DTensor to be passed into a compiled function, and allows fakifying a DTensor during dynamo tracing by turning the inner local tensor into a meta tensor.
- We use `allow_in_graph` to let `DTensor` and `DTensor.from_local` be represented as `TorchVariable`.
- The DTensor created becomes a normal `TensorVariable`, and it inserts any tensor operations into the output graph just like torch.Tensor does.
- Note that DTensor has a new instance method, `redistribute`, compared to a plain tensor; we currently handle it specially in `TensorVariable`.
`from_local` and `redistribute` both accept some non-trivial metadata as arguments (i.e. DeviceMesh, Placement) which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encode the metadata into a new function (similar to `functools.partial`) so that the new function only accepts prim args (i.e. tensors), and then we put a `call_function` node with this new function into the graph. This is suggested by @ezyang. The underlying rationale here is that the metadata will not change across graph invocations, so it's safe to encode it.
Captured graph:
```
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_

    # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False); l_x_ = None

    # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
    prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local); prim_from_local = None
    to_local = prim_redistribute.to_local(); prim_redistribute = None
    add = to_local + 2; to_local = None
    return (add,)
```
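For context, here is roughly the user code that produces a graph like the one above (reconstructed from the `test_dtensor.py` lines referenced in the graph; the mesh setup is an assumption and requires an initialized process group):
```
import torch
from torch.distributed._tensor import DTensor, DeviceMesh, Shard, Replicate

# Assumes torch.distributed has been initialized (e.g. via init_process_group)
# and that the mesh spans the participating ranks.
mesh = DeviceMesh("cuda", list(range(torch.cuda.device_count())))

def fn(x):
    dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    return dt.redistribute(mesh, [Replicate()]).to_local() + 2

compiled_fn = torch.compile(fn, backend="eager")
out = compiled_fn(torch.randn(4, 4, device="cuda"))
```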
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym