Commit Graph

175 Commits

Author SHA1 Message Date
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.
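For illustration (example is mine, not from the PR), the kind of rewrite the rule applies:

```python
# Before: redundant parentheses on an exception type raised with no arguments
def check_positive(x):
    if x <= 0:
        raise ValueError()

# After: ruff's RSE fix drops the empty call parentheses
def check_positive(x):
    if x <= 0:
        raise ValueError
```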

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Jason Ansel
f3fd280238 [dynamo] Relax strict_mode for autograd.Function forward inputs (#123910)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123910
Approved by: https://github.com/oulgen
2024-04-13 19:41:59 +00:00
Jason Ansel
70b8c58f84 [dynamo] Emit warning to turn on capture_scalar_outputs (#123896)
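For context, a minimal sketch (mine, not from the PR) of the config flag the warning points users at:

```python
import torch

# Opt in to tracing .item()/.tolist() calls instead of graph-breaking on them
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x):
    s = x.sum().item()  # traced as a symbolic scalar rather than causing a graph break
    return x * s

print(f(torch.randn(4)))
```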
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123896
Approved by: https://github.com/anijain2305
ghstack dependencies: #123700, #123705, #123786, #123790, #123803, #123804
2024-04-12 19:03:13 +00:00
Brian Hirsh
09be5800c8 dynamo: support placement kwargs for DTensor.to_local() (#119947)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119947
Approved by: https://github.com/wanchaol, https://github.com/yoyoyocmu
ghstack dependencies: #118803
2024-03-22 14:42:27 +00:00
Jason Ansel
477d154ffd [dynamo] Add missing _nonvar_fields annotations (#122219)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122219
Approved by: https://github.com/anijain2305
ghstack dependencies: #122218
2024-03-20 07:53:18 +00:00
Jason Ansel
153a01833b [dynamo] Optimize SourcelessBuilder (#122063)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 2.7s to 2.5s.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122063
Approved by: https://github.com/anijain2305
ghstack dependencies: #122039, #122043, #122055, #122058, #122060
2024-03-19 04:23:30 +00:00
Jason Ansel
8082adcf65 [dynamo] Only rename a proxy once (#122060)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 3.9s to 2.7s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122060
Approved by: https://github.com/oulgen
ghstack dependencies: #122039, #122043, #122055, #122058
2024-03-19 04:23:27 +00:00
Jason Ansel
2bec55c5f9 [dynamo] Remove VariableTracker.parents_tracker (#122058)
This is a leftover from the mutable-VariableTracker days and is no longer needed.

Improves benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py
from 4.2s to 3.9s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122058
Approved by: https://github.com/oulgen, https://github.com/anijain2305
ghstack dependencies: #122039, #122043, #122055
2024-03-19 04:23:24 +00:00
Jason Ansel
769ff86b91 [dynamo] Optimize COMPARE_OP (#122039)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 5.6s to 5.1s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122039
Approved by: https://github.com/Skylion007, https://github.com/anijain2305
2024-03-19 04:23:14 +00:00
Jason Ansel
5d52b163d1 [dynamo] Optimize load/store/const op handling (#122038)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 6.7s to 5.6s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122038
Approved by: https://github.com/Skylion007
ghstack dependencies: #122032, #122033, #122034, #122035
2024-03-18 18:08:06 +00:00
Jason Ansel
6ca0323615 [dynamo] Optimize VariableTracker.__post_init__ (#122034)
Improves `benchmarks/dynamo/microbenchmarks/dynamo_microbenchmarks.py`
from 8.6s to 7.3s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122034
Approved by: https://github.com/Skylion007
ghstack dependencies: #122032, #122033
2024-03-18 18:08:06 +00:00
Jason Ansel
7cc476ea16 [dynamo] Fix support for nn.Parameter constructor (part 1) (#120163)
This captures calls to `torch.nn.Parameter` by lifting them to graph inputs.
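A rough sketch (mine) of the pattern this enables; whether a given case traces cleanly depends on the follow-ups implied by "part 1":

```python
import torch

@torch.compile(backend="eager")
def make_param(x):
    # The nn.Parameter constructor is captured by lifting the new parameter
    # to a graph input instead of graph-breaking on it.
    p = torch.nn.Parameter(torch.ones_like(x))
    return x * p

out = make_param(torch.randn(4))
```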

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120163
Approved by: https://github.com/albanD, https://github.com/yanboliang
ghstack dependencies: #121086
2024-03-11 05:14:42 +00:00
Jason Ansel
32488b0664 [dynamo] Support _unsafe_set_version_counter (#121086)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121086
Approved by: https://github.com/yanboliang
2024-03-11 05:14:42 +00:00
Yukio Siraichi
aa0b0944d5 [dynamo] Re-dispatch torch.Tensor.new into torch.Tensor.new_empty method. (#121075)
Fix: https://github.com/pytorch/xla/issues/6009

This PR adds another case to the `TensorVariable.method_new` special-casing, where it
re-dispatches `new` into `new_empty`.

Since we are using fake tensors, the `new` call never actually reaches the corresponding
backend (e.g. XLA). So, things like the following might happen:

```python
@torch.compile(backend="openxla")
def foo(x):
    new_x = x.new(*x.size())

    # new_x.device() == "xla"
    # x.device() == "xla:0"

    return new_x + x

a = torch.arange(10)
foo(a.to(xm.xla_device()))
```

Resulting in the following error:

```python
Traceback (most recent call last):
  ...
  File "torch/_dynamo/utils.py", line 1654, in get_fake_value
    ret_val = wrap_fake_exception(
  File "torch/_dynamo/utils.py", line 1190, in wrap_fake_exception
    return fn()
  File "torch/_dynamo/utils.py", line 1655, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "torch/_dynamo/utils.py", line 1776, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "torch/_dynamo/utils.py", line 1758, in run_node
    return node.target(*args, **kwargs)
  File "torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "torch/_subclasses/fake_tensor.py", line 885, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1224, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 955, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "torch/_subclasses/fake_tensor.py", line 1445, in _dispatch_impl
    return self.wrap_meta_outputs_with_default_device_logic(
  File "torch/_subclasses/fake_tensor.py", line 1575, in wrap_meta_outputs_with_default_device_logic
    return tree_map(wrap, r)
  File "torch/utils/_pytree.py", line 900, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "torch/utils/_pytree.py", line 736, in unflatten
    leaves = list(leaves)
  File "torch/_subclasses/fake_tensor.py", line 1550, in wrap
    ) = FakeTensor._find_common_device(func, flat_args)
  File "torch/_subclasses/fake_tensor.py", line 625, in _find_common_device
    merge_devices(arg)
  File "torch/_subclasses/fake_tensor.py", line 620, in merge_devices
    raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='xla', size=(10,), dtype=torch.int64), FakeTensor(..., device='xla:0', size=(10,), dtype=torch.int64)), **{}):
Unhandled FakeTensor Device Propagation for aten.add.Tensor, found two different devices xla, xla:0
```

Using `new_empty` instead fixes this error, because it takes the device from the source
tensor rather than inferring it from the current dispatch key set.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121075
Approved by: https://github.com/jansel
2024-03-06 11:49:27 +00:00
Jason Ansel
4f19b5f7ef [dynamo] Remove extra guard for tensor constant attrs (#121106)
Also deletes some unused code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121106
Approved by: https://github.com/yanboliang, https://github.com/anijain2305
2024-03-05 17:16:04 +00:00
Pearu Peterson
ce2903080c Add sparse compressed fake tensor support (#120920)
As in the title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120920
Approved by: https://github.com/ezyang
2024-03-04 14:38:45 +00:00
Jason Ansel
01ec8df6d8 [Compiled Autograd] Introduce BackwardState capture (#120382)
This adds support for backwards hooks that are *both*:
1) Interior to the graph; and
2) Dynamically generated (e.g. lambdas)

We do this by creating a BackwardState object that is used to register the hooks in the forward, then populated by dynamo *after* the forward runs.
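A small sketch (mine) of the kind of hook this covers; the compiled-autograd setup needed to run it end to end is omitted:

```python
import torch

@torch.compile(backend="eager")
def forward(x):
    y = x.sin()                        # interior to the graph
    y.register_hook(lambda g: 2 * g)   # dynamically generated (lambda) backward hook
    return y.cos()

x = torch.randn(3, requires_grad=True)
forward(x).sum().backward()
```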

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
2024-02-28 20:36:47 +00:00
Animesh Jain
e3d64c4d5d [dynamo] Desugar accumulate_grad, fix .grad handling (#120590)
Fixes https://github.com/pytorch/pytorch/issues/118435
Fixes https://github.com/pytorch/pytorch/issues/119906

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120590
Approved by: https://github.com/ezyang, https://github.com/jansel
ghstack dependencies: #120520
2024-02-27 10:12:26 +00:00
Jason Ansel
2fea475215 [dynamo] Refactor reconstruct() not to return anything (#120150)
This simplifies things slightly and avoids some bugs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120150
Approved by: https://github.com/yanboliang
2024-02-17 17:13:41 +00:00
Brian Hirsh
26343451be DTensor: make tensor_flatten more compatible for dynamo getattr (#118209)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118209
Approved by: https://github.com/ezyang, https://github.com/wanchaol
ghstack dependencies: #117667, #117666
2024-02-16 21:16:07 +00:00
Brian Hirsh
ee7bcf23db dynamo: support attribute access on tensor subclasses without sources (#117666)
Fixes https://github.com/pytorch/pytorch/issues/117596

This was needed for Float8Tensor. Before this PR, dynamo would sometimes handle attribute access on tensor subclasses correctly, but it would choke on tensor subclasses with no source (it would fall back to using a `GetAttrVariable` to represent the attribute access, which is a problem if the attribute is a tensor that we later want to call tensor methods on).

I supported two cases:

(1) the attribute is a tensor, which is part of the `attrs` returned by the subclass's `__tensor_flatten__`. This creates a `TensorVariable`
(2) the attribute is a constant, which is part of the constant metadata returned by `__tensor_flatten__`. As per the contract of tensor_flatten, this should be a `ConstantVariable`. It could be possible that we allow non-constant metadata in the future, but we don't support that today.
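For reference, a minimal hypothetical wrapper-subclass sketch showing the two attribute kinds above (names are illustrative; the `__tensor_unflatten__`/`__torch_dispatch__` plumbing a real traceable subclass needs is omitted):

```python
import torch

class ScaledTensor(torch.Tensor):
    # Hypothetical wrapper subclass, loosely in the spirit of Float8Tensor.
    @staticmethod
    def __new__(cls, elem, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.shape, dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem, scale):
        self._elem = elem    # case (1): tensor attr listed by __tensor_flatten__ -> TensorVariable
        self._scale = scale  # case (2): constant metadata -> ConstantVariable

    def __tensor_flatten__(self):
        # (names of tensor attributes, constant metadata) per the tensor_flatten contract
        return ["_elem"], {"scale": self._scale}
```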

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117666
Approved by: https://github.com/zou3519
ghstack dependencies: #117667
2024-02-16 21:16:07 +00:00
Jason Ansel
75a6d6aef7 [inductor] Support storage resizing (#119749)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119749
Approved by: https://github.com/yf225
ghstack dependencies: #119647, #119671
2024-02-14 03:03:38 +00:00
Jason Ansel
39c68efd85 [dynamo] Capture untyped_storage().resize_() (#119647)
This makes storage resizing work with `backend=eager`; the next two PRs make it work for inductor.
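A rough sketch (mine) of the FSDP-style pattern this is aimed at; names are illustrative:

```python
import torch

def reuse_buffer(src, buf):
    # Reallocate buf's freed storage, then fill it from src
    buf.untyped_storage().resize_(src.untyped_storage().nbytes())
    buf.copy_(src)
    return buf + 1

src = torch.randn(4)
buf = torch.empty(4)
buf.untyped_storage().resize_(0)  # storage freed up front, as FSDP does between uses
out = torch.compile(reuse_buffer, backend="eager")(src, buf)
```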

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119647
Approved by: https://github.com/yf225
2024-02-13 19:03:28 +00:00
Jason Ansel
74d55b0e63 [dynamo] Support torch.distributed.fsdp._flat_param._same_storage_size (#119627)
Replaces #117690

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119627
Approved by: https://github.com/Skylion007
2024-02-13 01:27:37 +00:00
Oguz Ulgen
e693089c7a [Dynamo] Refactor tensor methods handling (#119581)
Fixes part of #119128

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119581
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-02-10 08:46:50 +00:00
Jason Ansel
e1c1b8c2b2 [dynamo] Improve support for backwards hooks (#119525)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119525
Approved by: https://github.com/yanboliang, https://github.com/anijain2305
2024-02-10 01:14:03 +00:00
PyTorch MergeBot
25a0fa6d13 Revert "[dynamo] Improve support for backwards hooks (#119525)"
This reverts commit b1f4b2a63c.

Reverted https://github.com/pytorch/pytorch/pull/119525 on behalf of https://github.com/clee2000 due to broke test_autograd.py::TestAutograd::test_post_accumulate_grad_hook_gets_cleaned_up on dynamo https://github.com/pytorch/pytorch/actions/runs/7847212828/job/21416215820 b1f4b2a63c.  The failure exists on the PR as well, but got masked by the other test.  Putting this as no signal? ([comment](https://github.com/pytorch/pytorch/pull/119525#issuecomment-1936447169))
2024-02-09 18:58:55 +00:00
Jason Ansel
b1f4b2a63c [dynamo] Improve support for backwards hooks (#119525)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119525
Approved by: https://github.com/yanboliang
2024-02-09 17:02:40 +00:00
Jason Ansel
62cc1053d8 [dynamo] Fix missing guards in FunctoolsPartialVariable (#118616)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118616
Approved by: https://github.com/yanboliang
ghstack dependencies: #118901
2024-02-06 23:42:43 +00:00
Jason Ansel
7a52455102 [dynamo] Refactor TensorVariable method handling (#119111)
This should slightly improve compile times and be easier to maintain.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119111
Approved by: https://github.com/yanboliang, https://github.com/anijain2305
2024-02-03 17:18:19 +00:00
ydwu4
9fe3693bbb [dynamo] bypass graph break due to masking if inference mode (#119056)
Relax the constraints in https://github.com/pytorch/pytorch/issues/114123 when we're in inference mode.

Test Plan:
See added tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119056
Approved by: https://github.com/ezyang, https://github.com/zou3519
2024-02-02 22:53:23 +00:00
Edward Z. Yang
d03173e88c Unify MYPYINDUCTOR and MYPY (#118432)
The original motivation for MYPYINDUCTOR was a faster type checking configuration that only checked a subset of files. With the removal of `follow_imports = ignore`, we are now able to use dmypy to do fast incremental typechecking, eliminating the need for this.

Perhaps erroneously, when I teed up this PR I elected to delete the `follow_imports = skip` designations in mypy-inductor.ini. This led to a number of extra type-error suppressions that I manually edited. You will need to review those.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118432
Approved by: https://github.com/Skylion007
ghstack dependencies: #118414, #118418
2024-01-27 17:23:20 +00:00
ydwu4
fae569b4f2 [dynamo] avoid graph break on tensor.element_size() (#118229)
Before this PR, the following code hit a graph break: `torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_method element_size`
```python
import torch
def f(x):
  return x.sin().element_size() + x.sin()

x = torch.randn(2, 2)
torch.compile(f, backend="eager", fullgraph=True)(x)
```
After this PR, we get the following graph, where `element_size()` is baked in as a constant (4, since the tensor is float32).
```python
[2024-01-24 13:49:02,814] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-01-24 13:49:02,814] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]     def forward(self, L_x_ : torch.Tensor):
[2024-01-24 13:49:02,814] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_x_ = L_x_
[2024-01-24 13:49:02,814] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]
[2024-01-24 13:49:02,814] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: /home/yidi/local/pytorch/test.py:4 in f, code: return x.sin().element_size() + x.sin()
[2024-01-24 13:49:02,814] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sin = l_x_.sin()
[2024-01-24 13:49:02,814] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         sin_1 = l_x_.sin();  l_x_ = None
[2024-01-24 13:49:02,814] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         add = 4 + sin_1;  sin_1 = None
[2024-01-24 13:49:02,814] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         return (add,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118229
Approved by: https://github.com/yanboliang, https://github.com/jansel, https://github.com/anijain2305
2024-01-25 22:28:37 +00:00
lezcano
4512a95371 [easy]Remove specialized value (#112252)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112252
Approved by: https://github.com/jansel
2024-01-18 09:34:50 +00:00
voznesenskym
83e8a0721d Reland #111196 (take 4) "Support tensors as Dict keys" (#116934)
Fixes #ISSUE_NUMBER

See that PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116934
Approved by: https://github.com/ezyang, https://github.com/huydhn
2024-01-07 01:37:26 +00:00
PyTorch MergeBot
2dca3e99eb Revert "Support tensors as Dict keys Re-PR of #111196 (#116785)"
This reverts commit 1badad9ce9.

Reverted https://github.com/pytorch/pytorch/pull/116785 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/116785#issuecomment-1879592261))
2024-01-06 08:22:33 +00:00
voznesenskym
1badad9ce9 Support tensors as Dict keys Re-PR of #111196 (#116785)
This prepares the PR where we implement sets in terms of dicts.
Rather than internally storing a dictionary that maps literals to
VariableTrackers, we now store (pretty much) a dictionary from VTs to VTs.
To do so, keys are wrapped in an opaque internal class `_Hashable`.
The class is opaque on purpose so that it fails hard if it
inadvertently leaks back into user code.
We also found and fixed a number of latent bugs and inconsistencies
in the way dynamo checked what can be a dict key. More generally, we
make it much clearer what needs to be modified to add a new supported
key type to Dicts.
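A user-facing sketch (mine) of what this allows inside compiled code:

```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def f(x, y):
    # Tensors can be used directly as dict keys; dynamo wraps them in the
    # opaque _Hashable internally, so user code never sees the wrapper.
    d = {x: 1.0, y: 2.0}
    return d[x] * x + d[y] * y

a, b = torch.randn(3), torch.randn(3)
print(f(a, b))
```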

Fixes [#107595](https://www.internalfb.com/tasks?t=107595)
Fixes [#111603](https://www.internalfb.com/tasks?t=111603)

Re-PR of https://github.com/pytorch/pytorch/pull/111196; sadly, due to reverts, we could not reuse @lezcano's original PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116785
Approved by: https://github.com/mlazos
2024-01-06 03:35:35 +00:00
Yanbo Liang
f657b2b1f8 [Dynamo][10/N] Remove TorchVariable and is_allowed (#116312)
After this refactor:
* ```TorchVariable``` definition and all references are removed.
* All ```is_allowed``` references except one are removed.
  - The only one left is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```. It is called when users put the ```disallow_in_graph``` decorator on a function. Since we use the lists in ```trace_rules``` to decide a function's trace rule, the decorator would only be used on custom functions rather than torch functions. I'll defer this to a separate decorator refactor PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116312
Approved by: https://github.com/jansel
2023-12-27 18:47:05 +00:00
PyTorch MergeBot
3b709d7c1e Revert "[Dynamo][10/N] Remove TorchVariable and is_allowed (#116312)"
This reverts commit 015bd0e0a1.

Reverted https://github.com/pytorch/pytorch/pull/116312 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/116312#issuecomment-1869825506))
2023-12-26 23:47:15 +00:00
Yanbo Liang
015bd0e0a1 [Dynamo][10/N] Remove TorchVariable and is_allowed (#116312)
After this refactor:
* ```TorchVariable``` definition and all references are removed.
* All ```is_allowed``` references except one are removed.
  - The only one left is in ```torch/_dynamo/decorators:_disallow_in_graph_helper```. It is called when users put the ```disallow_in_graph``` decorator on a function. Since we use the lists in ```trace_rules``` to decide a function's trace rule, the decorator would only be used on custom functions rather than torch functions. I'll defer this to a separate decorator refactor PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116312
Approved by: https://github.com/jansel
2023-12-23 09:44:09 +00:00
Guilherme Leobas
1be6a070bc Add support for torch.cond in vmap (#114523)
Fixes: https://github.com/pytorch/pytorch/issues/114136

The patch enables converting a BatchedTensor into a FakeTensor and implements
torch.cond vmap support using torch.where.
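A minimal sketch (mine) of cond under vmap; per the description, batched predicates are handled by evaluating both branches and selecting with torch.where:

```python
import torch
from torch.func import vmap

def branchy(pred, x):
    return torch.cond(pred, lambda t: t.sin(), lambda t: t.cos(), (x,))

preds = torch.tensor([True, False])
xs = torch.randn(2, 3)
out = vmap(branchy)(preds, xs)  # row 0 -> sin(xs[0]), row 1 -> cos(xs[1])
```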

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114523
Approved by: https://github.com/zou3519
2023-12-20 19:54:38 +00:00
Michael Lazos
8eb7f6276b Ensure wrapping subclasses with as_subclass is supported (#116091)
As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116091
Approved by: https://github.com/pmeier, https://github.com/zou3519
2023-12-20 14:37:08 +00:00
Michael Lazos
fbeca60b1f Remove replace_all and make VTs mutable (#113725)
1.  Removes calls to `replace_all` and `clone` and makes VTs mutable.
2. Properly handles Tuple Iterator mutation. Previously TupleIterator variables would only be properly reconstructed if they were advanced at least once in a frame. On calls to `next`, the source information would be lost (due to constructing a new iterator without using builder), which would ensure that during codegen the variable would be reconstructed from scratch. Now that VTs are mutated, the source is never lost, so we need to properly track mutation and handle it by replaying calls to `next` at the end of the modified bytecode.
3. Added test for checking iadd side effects, this was missing in our unit test coverage.
4. Fixed two incorrect sources: DelayGraphBreakVariable and UserMethodVariable both relied on setting the source to AttrSource(parent, name) at the call site of `var_getattr`.
5. Fixed a bug in in-place adds for lists: it would set the resulting VariableTracker's source to `None`, which would take a different reconstruct path in codegen. Now this is handled explicitly by reconstructing vars when `allow_cache=False`, so that during side-effect replay the mutated var is correctly updated.

In subsequent PRs:
* Refactoring side effect tracking to be significantly simpler (I think we only need an `is_modified` flag)
* Refactor `next_variables` iterator to match the signature of `next`
* Remove all references to `options` in the code
* Refactor VTs representing mutable collections to implement their own mutation update handling
* Remove clone and/or make it specific to lists for creating slices
* Add mutation tracking/replay for sets
* Add mutation tracking/replay for iter.py
* Stop setting the source in builder (it's set at the top level after a var is returned)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725
Approved by: https://github.com/jansel
2023-12-10 09:31:21 +00:00
Michael Lazos
3c882925da Make subclass type instances constants (like UserDefinedClasses) (#115323)
As title

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115323
Approved by: https://github.com/oulgen
2023-12-07 08:10:59 +00:00
Jason Ansel
522bae20df [dynamo] Support any() on SymNodeVariable (#115119)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115119
Approved by: https://github.com/yanboliang
ghstack dependencies: #115095, #115046, #115057
2023-12-05 19:01:31 +00:00
voznesenskym
ddf1cb7870 AOTAutograd: handle set_(), detect metadata mutations that cancel out (#111554)
This should be enough to get @voznesenskym 's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything properly no-op out. Main changes are:

(1) graph break on `aten::set_.source_Tensor_storage_offset` (we could support it but it isn't needed, seems safer to graph break)

(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and it was wrong (it would just clone() and call set_(), which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.

(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations, and metadata mutations due to `set_()` calls. This is mainly because at runtime, one requires calling `as_strided_()` to fix up metadata, while the other requires calling `set_()`.

(4) Made AOTAutograd's detection of metadata mutations / `set_()` mutations smarter, so it detects no-ops (if the storage and metadata are all the same).

I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with the existing `has_data_mutation()` and a new `has_metadata_mutation()`, which more accurately distinguish between data mutations, `set_()` calls, and metadata mutations.

**This PR is still silently correct in one case though**, which I'd like to discuss more. In particular, this example:
```
def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```

If you have an input that experiences both a data-mutation **and** a `x_old.set_(x_new)` call, there are two cases:

(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if x_new is a graph intermediate then we will functionalize the mutation. If x_new is a different graph input, then we will perform the usual `copy_()` on that other graph input

(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```

def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations: a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output
def runtime_wrapper(x):
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```

There are two things that make this difficult to do though:

(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.

(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.

It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554
Approved by: https://github.com/ezyang
ghstack dependencies: #113926
2023-11-28 19:33:35 +00:00
lezcano
0bb2600c28 Allow to differentiate through NumPy code (#114608)
With this PR it is possible to differentiate through NumPy code modulo
the usual caveats that apply to differentiation:
- That there are no graphbreaks
- That the decomposition in `torch._numpy` is differentiable

@ev-br and I were somewhat careful to achieve the second point, but
it is not tested through and through, so YMMV
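A rough sketch (mine) of the kind of program this enables; the exact boundary conversions are an assumption on my part:

```python
import numpy as np
import torch

@torch.compile
def f(x):
    # NumPy calls inside the compiled region are traced via torch._numpy,
    # so autograd can see through them.
    y = np.sin(x.numpy()) ** 2
    return torch.from_numpy(y)

x = torch.randn(4, requires_grad=True)
f(x).sum().backward()  # gradients flow through the NumPy code
print(x.grad)
```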

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114608
Approved by: https://github.com/voznesenskym
2023-11-28 12:04:37 +00:00
PyTorch MergeBot
3e1abde46d Revert "AOTAutograd: handle set_(), detect metadata mutations that cancel out (#111554)"
This reverts commit a911b4db9d.

Reverted https://github.com/pytorch/pytorch/pull/111554 on behalf of https://github.com/DanilBaibak due to The lower PR in the stack #113926 breaks the internal build ([comment](https://github.com/pytorch/pytorch/pull/111554#issuecomment-1822472206))
2023-11-22 10:13:48 +00:00
Jon Chuang
f66add9b85 [dynamo] graph break on np.ndarray.tobytes (#114208)
We can't model this accurately across np and tnp; see https://github.com/pytorch/pytorch/issues/114204#issuecomment-1820269949

So let's not even try. Just graph break.

Fixes: https://github.com/pytorch/pytorch/issues/114204

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114208
Approved by: https://github.com/lezcano
2023-11-21 18:19:37 +00:00
voznesenskym
a911b4db9d AOTAutograd: handle set_(), detect metadata mutations that cancel out (#111554)
This should be enough to get @voznesenskym 's FSDP branch to plumb `set_()` through AOTAutograd properly and have everything properly no-op out. Main changes are:

(1) graph break on `aten::set_.source_Tensor_storage_offset` (we could support it but it isn't needed, seems safer to graph break)

(2) Functionalization: add a "proper" functionalization kernel for `aten::set_.source_Tensor`. The previous one we had was codegen'd and it was wrong (it would just clone() and call set_(), which does not do the right thing). I also manually mark on the `FunctionalTensorWrapper` when a given tensor has been mutated by a `set_()` call.

(3) AOTAutograd: I added a new field, `InputAliasInfo.mutates_storage_metadata`, so we can distinguish between "regular" metadata mutations, and metadata mutations due to `set_()` calls. This is mainly because at runtime, one requires calling `as_strided_()` to fix up metadata, while the other requires calling `set_()`.

(4) Made AOTAutograd's detection of metadata mutations / `set_()` mutations smarter, so it detects no-ops (if the storage and metadata are all the same).

I also killed `was_updated()` and `was_metadata_updated()`, and replaced them with the existing `has_data_mutation()` and a new `has_metadata_mutation()`, which more accurately distinguish between data mutations, `set_()` calls, and metadata mutations.

**This PR is still silently correct in one case though**, which I'd like to discuss more. In particular, this example:
```
def f(x):
    x_view = x.view(-1)
    x.set_(torch.ones(2))
    x_view.mul_(2)
    return
```

If you have an input that experiences both a data-mutation **and** a `x_old.set_(x_new)` call, there are two cases:

(a) the data mutation happened on the storage of `x_new`. This case should be handled automatically: if x_new is a graph intermediate then we will functionalize the mutation. If x_new is a different graph input, then we will perform the usual `copy_()` on that other graph input

(b) the data mutation happened on the storage of `x_old`. This is more of a pain to handle, and doesn't currently work. At runtime, the right thing to do is probably something like:
```

def functionalized_f(x):
    x_view = x.view(-1)
    # set_() desugars into a no-op; later usages of x will use x_output
    x_output = torch.ones(2)
    # functionalize the mutation on x_view
    x_view_updated = x.mul(2)
    x_updated = x_view_updated.view(x.shape)
    # x experienced TWO TYPES of mutations: a data mutation and a metadata mutation
    # We need to return both updated tensors in our graph
    return x_updated, x_output
def runtime_wrapper(x):
    x_data_mutation_result, x_set_mutation_result = compiled_graph(x)
    # First, perform the data mutation on x's old storage
    x.copy_(x_data_mutation_result)
    # Then, swap out the storage of x with the new storage
    x.set_(x_set_mutation_result)
```

There are two things that make this difficult to do though:

(1) Functionalization: the functionalization rule for `set_()` will fully throw away the old `FunctionalStorageImpl` on the graph input. So if there are any mutations to that `FunctionalStorageImpl` later on in the graph, the current graph input won't know about it. Maybe we can have a given `FunctionalTensorWrapper` remember all previous storages that it had, and track mutations on all of them - although this feels pretty complicated.

(2) AOTAutograd now needs to know that we might have *two* graph outputs that correspond to a single "mutated input", which is annoying.

It's worth pointing out that this issue is probably extremely unlikely for anyone to run into - can we just detect it and error? This feels slightly easier than solving it, although not significantly easier. We would still need `FunctionalTensorWrapper` to keep track of mutations on any of its "previous" storages, so it can report this info back to AOTAutograd so we can raise an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111554
Approved by: https://github.com/ezyang
ghstack dependencies: #113926
2023-11-21 01:52:46 +00:00