Fix: https://github.com/pytorch/xla/issues/6009
This PR adds another case to `TensorVariable.method_new` special case, where it
re-dispatches `new` into `new_empty`.
Since we are using fake tensors, the `new` call doesn't actually get to the corresponding
backend (e.g. XLA). So, things like the following might happen:
```python
import torch
import torch_xla.core.xla_model as xm


@torch.compile(backend="openxla")
def foo(x):
    new_x = x.new(*x.size())
    # new_x.device == "xla"
    # x.device == "xla:0"
    return new_x + x


a = torch.arange(10)
foo(a.to(xm.xla_device()))
```
Resulting in the following error:
```python
Traceback (most recent call last):
...
File "torch/_dynamo/utils.py", line 1654, in get_fake_value
ret_val = wrap_fake_exception(
File "torch/_dynamo/utils.py", line 1190, in wrap_fake_exception
return fn()
File "torch/_dynamo/utils.py", line 1655, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "torch/_dynamo/utils.py", line 1776, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "torch/_dynamo/utils.py", line 1758, in run_node
return node.target(*args, **kwargs)
File "torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "torch/_subclasses/fake_tensor.py", line 885, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "torch/_subclasses/fake_tensor.py", line 1224, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "torch/_subclasses/fake_tensor.py", line 955, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "torch/_subclasses/fake_tensor.py", line 1445, in _dispatch_impl
return self.wrap_meta_outputs_with_default_device_logic(
File "torch/_subclasses/fake_tensor.py", line 1575, in wrap_meta_outputs_with_default_device_logic
return tree_map(wrap, r)
File "torch/utils/_pytree.py", line 900, in tree_map
return treespec.unflatten(map(func, *flat_args))
File "torch/utils/_pytree.py", line 736, in unflatten
leaves = list(leaves)
File "torch/_subclasses/fake_tensor.py", line 1550, in wrap
) = FakeTensor._find_common_device(func, flat_args)
File "torch/_subclasses/fake_tensor.py", line 625, in _find_common_device
merge_devices(arg)
File "torch/_subclasses/fake_tensor.py", line 620, in merge_devices
raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='xla', size=(10,), dtype=torch.int64), FakeTensor(..., device='xla:0', size=(10,), dtype=torch.int64)), **{}):
Unhandled FakeTensor Device Propagation for aten.add.Tensor, found two different devices xla, xla:0
```
Using `new_empty`, instead, fixes this error because it uses the device from the source
tensor, instead of inferring from the current dispatch key set.
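The device mismatch can be illustrated with a toy model. This is only a sketch: `Device` and `merge_devices` below are hypothetical stand-ins written for this example, not the actual `FakeTensor` code. It assumes that a device without an index (`xla`) compares unequal to one with an index (`xla:0`), which is what the error message above suggests.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Device:
    """Toy stand-in for a device: a type plus an optional index."""
    type: str
    index: Optional[int] = None


def merge_devices(a: Device, b: Device) -> Device:
    """Return the common device, or raise if the two differ (toy check)."""
    if a != b:
        raise RuntimeError(f"found two different devices {a}, {b}")
    return a


source = Device("xla", 0)   # like x.device == "xla:0"
inferred = Device("xla")    # like the device that `new` ended up with
copied = source             # like `new_empty`, which reuses the source's device

try:
    merge_devices(inferred, source)
    mismatch = False
except RuntimeError:
    mismatch = True         # the inferred device has no index, so the merge fails

same = merge_devices(copied, source)  # succeeds: both sides carry "xla:0"
```

Reusing the source tensor's device sidesteps the comparison problem entirely, which is the reason for re-dispatching to `new_empty`.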
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121075
Approved by: https://github.com/jansel
This adds support for backwards hooks that are *both*:
1) Interior to the graph; and
2) Dynamically generated (e.g. lambdas)
We do this by creating a BackwardState object that is used to register the hooks in the forward pass and is then populated by dynamo *after* the forward runs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
Fixes https://github.com/pytorch/pytorch/issues/117596
This was needed for Float8Tensor. Before this PR, dynamo would sometimes handle attribute access on tensor subclasses correctly, but it would choke on tensor subclasses with no source (it would fall back to using a `GetAttrVariable` to represent the attribute access, which is a problem if the attribute is a tensor that we later want to call tensor methods on).
I supported two cases:
(1) the attribute is a tensor, which is part of the `attrs` returned by the subclass's `__tensor_flatten__`. This creates a `TensorVariable`
(2) the attribute is a constant, which is part of the constant metadata returned by `__tensor_flatten__`. As per the contract of tensor_flatten, this should be a `ConstantVariable`. It could be possible that we allow non-constant metadata in the future, but we don't support that today.
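The two cases can be sketched with a toy class. Everything here is hypothetical: `ToySubclass` is a plain Python object rather than a real tensor subclass, `classify_attr` is not dynamo's implementation, and the sketch assumes the constant metadata is a dict keyed by attribute name.

```python
class ToySubclass:
    """Hypothetical stand-in for a tensor subclass (not a real torch.Tensor)."""

    def __init__(self, data, scale):
        self._data = data    # pretend this is an inner tensor
        self._scale = scale  # constant metadata

    def __tensor_flatten__(self):
        # Returns (names of inner-tensor attributes, constant metadata).
        return ["_data"], {"_scale": self._scale}


def classify_attr(obj, name):
    """Decide how an attribute access would be represented (toy logic)."""
    attrs, ctx = obj.__tensor_flatten__()
    if name in attrs:
        return "TensorVariable"    # case (1): inner tensor
    if isinstance(ctx, dict) and name in ctx:
        return "ConstantVariable"  # case (2): constant metadata
    return "unsupported"


t = ToySubclass(data=[1.0, 2.0], scale=0.5)
kinds = [classify_attr(t, n) for n in ("_data", "_scale", "_other")]
```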
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117666
Approved by: https://github.com/zou3519
ghstack dependencies: #117667
The original motivation for MYPYINDUCTOR was a faster type checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this.
Perhaps erroneously, when I tee'ed up this PR I elected to delete the `follow_imports = skip` designations in the mypy-inductor.ini. This led to a number of extra type error suppressions that I manually edited. You will need to review.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
This prepares the PR where we implement sets in terms of dicts.
To do so, rather than storing internally a dictionary that maps literals
to VariableTrackers, it stores (pretty much) a dictionary from VTs to VTs.
For this, keys are wrapped in an opaque internal class `_Hashable`.
The `_Hashable` class is opaque on purpose so that it fails hard
if it inadvertently leaks back into user code.
We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make much clearer what are the things that need to be modified to add
a new supported key type to Dicts.
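A minimal sketch of the key-wrapping idea, under stated assumptions: `HashableKey` and `Tracker` are hypothetical names for this example, not the classes in the PR, and the wrapper simply hashes and compares by the tracked constant value.

```python
class Tracker:
    """Toy stand-in for a VariableTracker holding a constant value."""

    def __init__(self, value):
        self.value = value


class HashableKey:
    """Hypothetical wrapper: hashes and compares by the tracked value.

    It exposes nothing else, so using it as if it were the underlying
    value (e.g. indexing or arithmetic) raises immediately.
    """

    def __init__(self, tracker):
        self.tracker = tracker

    def __hash__(self):
        return hash(self.tracker.value)

    def __eq__(self, other):
        return (
            isinstance(other, HashableKey)
            and self.tracker.value == other.tracker.value
        )


items = {}
items[HashableKey(Tracker("a"))] = Tracker(1)
items[HashableKey(Tracker("a"))] = Tracker(2)  # equal key, overwrites the entry
items[HashableKey(Tracker("b"))] = Tracker(3)

found = items[HashableKey(Tracker("a"))].value
```

Because two distinct tracker objects with the same value map to the same entry, the dictionary behaves like a mapping from trackers to trackers while still using ordinary Python hashing.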
Fixes [#107595](https://www.internalfb.com/tasks?t=107595)
Fixes [#111603](https://www.internalfb.com/tasks?t=111603)
Re-PR of https://github.com/pytorch/pytorch/pull/111196. Sadly, due to reverts, we could not reuse @lezcano's original PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116785
Approved by: https://github.com/mlazos
After this refactor:
* ```TorchVariable``` definition and all references are removed.
* All ```is_allowed``` references except one are removed.
- The only remaining one is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```. It is called when users put the ```disallow_in_graph``` decorator on a function. Since we use the lists in ```trace_rules``` to decide a function's trace rule, the decorator would only be used on custom functions rather than torch functions. I'll defer this to a separate decorator refactor PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116312
Approved by: https://github.com/jansel
1. Removes calls to `replace_all` and `clone` and makes VTs mutable.
2. Properly handles Tuple Iterator mutation. Previously TupleIterator variables would only be properly reconstructed if they were advanced at least once in a frame. On calls to `next`, the source information would be lost (due to constructing a new iterator without using builder), which would ensure that during codegen the variable would be reconstructed from scratch. Now that VTs are mutated, the source is never lost, so we need to properly track mutation and handle it by replaying calls to `next` at the end of the modified bytecode.
3. Added a test for checking iadd side effects; this was missing in our unit test coverage.
4. Fixed two incorrect sources: `DelayGraphBreakVariable` and `UserMethodVariable` both relied on setting the source to `AttrSource(parent, name)` at the callsite of `var_getattr`.
5. Fixed a bug in in-place adding for lists: it would set the resulting VariableTracker's source to `None`, which would use a different reconstruct path in codegen. Now this is handled explicitly by reconstructing vars when `allow_cache=False`, so that during side effect replay, the mutated var is correctly updated.
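The iterator replay described in item 2 can be sketched in plain Python. This is a toy model only: `TracedTupleIterator` and `replay_next_calls` are hypothetical names, and real dynamo emits bytecode rather than calling a helper like this.

```python
class TracedTupleIterator:
    """Toy tracker: records how far the iterator advanced during tracing,
    without touching the real iterator object."""

    def __init__(self, items):
        self.items = items
        self.index = 0

    def next(self):
        value = self.items[self.index]
        self.index += 1
        return value


def replay_next_calls(real_iterator, times_advanced):
    """Apply the recorded mutation by calling `next` on the real iterator."""
    for _ in range(times_advanced):
        next(real_iterator)


source = (10, 20, 30, 40)
real_it = iter(source)                 # the user's iterator, untouched while tracing
tracker = TracedTupleIterator(source)

seen = [tracker.next(), tracker.next()]    # two `next` calls observed during tracing
replay_next_calls(real_it, tracker.index)  # replay them after the traced code ran
remaining = list(real_it)
```

After the replay, the real iterator is in the same state it would have reached had the original code run eagerly.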
In subsequent PRs:
* Refactor side effect tracking to be significantly simpler (I think we only need an `is_modified` flag)
* Refactor `next_variables` iterator to match the signature of `next`
* Remove all references to `options` in the code
* Refactor VTs representing mutable collections to implement their own mutation update handling
* Remove clone and/or make it specific to lists for creating slices
* Add mutation tracking/replay for sets
* Add mutation tracking/replay for iter.py
* Remove setting source in builder (it's set at the top level after a var is returned)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725
Approved by: https://github.com/jansel
This should be enough to get @voznesenskym 's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything properly no-op out. Main changes are:
(1) graph break on `aten::set_.source_Tensor_storage_offset` (we could support it but it isn't needed, seems safer to graph break)
(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and it was wrong (it would just clone() and call set_(), which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.
(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations, and metadata mutations due to `set_()` calls. This is mainly because at runtime, one requires calling `as_strided_()` to fix up metadata, while the other requires calling `set_()`.
(4) Made AOTAutograd's detection for metadata mutations / set_() mutations smarter and detect no-ops (if the storage and metadata are all the same).
I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with (existing) `has_data_mutation()` and (new) `has_metadata_mutation()`, which can more accurately distinguish between data mutations, `set_()` calls, and metadata mutations.
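As a rough illustration of points (3) and (4), the distinction between the two kinds of metadata mutation and the no-op case could be modeled as below. This is a sketch under assumptions: `InputSnapshot` and `runtime_fixup` are hypothetical names, and AOTAutograd's real detection works on tensors and functional wrappers rather than on a record like this.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class InputSnapshot:
    """Hypothetical record of a graph input's storage and metadata."""
    storage_id: int
    size: Tuple[int, ...]
    stride: Tuple[int, ...]
    storage_offset: int


def runtime_fixup(before: InputSnapshot, after: InputSnapshot) -> str:
    """Name the call needed to fix up the input at runtime (toy logic)."""
    if before == after:
        return "no-op"        # storage and metadata are all the same
    if before.storage_id != after.storage_id:
        return "set_"         # storage was swapped by a set_() call
    return "as_strided_"      # regular metadata mutation on the same storage


x = InputSnapshot(storage_id=1, size=(4,), stride=(1,), storage_offset=0)
reshaped = InputSnapshot(storage_id=1, size=(2, 2), stride=(2, 1), storage_offset=0)
swapped = InputSnapshot(storage_id=2, size=(4,), stride=(1,), storage_offset=0)

fixups = [runtime_fixup(x, x), runtime_fixup(x, reshaped), runtime_fixup(x, swapped)]
```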
**This PR is still silently incorrect in one case though**, which I'd like to discuss more. In particular, this example:
```
import torch


def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```
If you have an input that experiences both a data-mutation **and** a `x_old.set_(x_new)` call, there are two cases:
(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if x_new is a graph intermediate then we will functionalize the mutation. If x_new is a different graph input, then we will perform the usual `copy_()` on that other graph input
(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```
def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x_view.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations: a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output


def runtime_wrapper(x):
    # compiled_graph is a placeholder for the compiled form of functionalized_f
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```
There are two things that make this difficult to do though:
(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.
(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.
It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554
Approved by: https://github.com/ezyang
ghstack dependencies: #113926
With this PR it is possible to differentiate through NumPy code modulo
the usual caveats that apply to differentiation:
- That there are no graph breaks
- That the decomposition in `torch._numpy` is differentiable
@ev-br and I were somewhat careful to achieve the second point, but
it is not tested through and through, so YMMV
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114608
Approved by: https://github.com/voznesenskym