Commit Graph

316 Commits

Author SHA1 Message Date
Edward Z. Yang
f19e07b056 Memoize local_scalar_dense calls, refactor all memos (#125623)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125623
Approved by: https://github.com/eellison
2024-05-11 21:12:35 +00:00
PyTorch MergeBot
c6e5d0d2e6 Revert "Memoize local_scalar_dense calls, refactor all memos (#125623)"
This reverts commit fcbf2b61e6.

Reverted https://github.com/pytorch/pytorch/pull/125623 on behalf of https://github.com/malfet due to Broke ROCM, see https://github.com/pytorch/pytorch/actions/runs/9026074378/job/24804583041 ([comment](https://github.com/pytorch/pytorch/pull/125623#issuecomment-2105444091))
2024-05-11 01:58:39 +00:00
Brian Hirsh
f25c7c9699 functionalize storage resizing, minimal ppFSDP traceable forward (#122434)
More details further down, but first a high-level description of "how do we functionalize storage resizing"

Today, dynamo converts `param.untyped_storage().resize_(x)` calls that it sees from fsdp into a custom op, `ops.inductor.resize_storage_bytes_(x)`

So given this setup, there are 3 main cases that I think we want to handle:

(1) graph input starts with a real storage size, gets resized down to zero in the graph
(2) graph input starts with 0 storage size, gets resized up in the graph
(3) graph input starts with 0 storage size, gets resized up and used in some compute, then resized back down to 0

For case (1) we need to emit a `resize_storage_bytes_` at the end of the graph, similar to how we emit `copy_()` for data mutations.

For case (2), we need to emit a `resize_storage_bytes_` in the graph, and we **also** need to emit a `copy_()` (the input had its storage resized up and filled in with data, which we need to reflect as an input mutation)

For case (3), the net effect is that the input had no data on entry and exit of the function, so we don't need to emit any mutable ops at the end of the graph.
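
To make the pattern concrete, here is a minimal eager-mode sketch of case (3), with illustrative names standing in for the real FSDP parameter and all-gather buffer:

```
import torch

param = torch.zeros(4)
all_gather_buffer = torch.ones(4)
param.untyped_storage().resize_(0)  # entry state for cases (2)/(3): zero storage

def forward(param, all_gather_buffer):
    # resize up and fill in data; dynamo rewrites the resize_ into
    # ops.inductor.resize_storage_bytes_
    param.untyped_storage().resize_(all_gather_buffer.untyped_storage().nbytes())
    param.copy_(all_gather_buffer)
    out = param @ param
    param.untyped_storage().resize_(0)  # back down to zero: case (3), net no-op
    return out
```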

The main thing to call out is that we need to write a functionalization rule for `resize_storage_bytes_` (`FunctionalTensorWrapper::storage_resize_()`), and this rule actually does very little. We would like to **not** emit any new ops in the graph (like say, a functional resize op). Instead, we should expect / rely on the fact that any resize up will be immediately followed by a `copy_()`/`foreach_copy_`/`out=` op, that will fill in the data of the tensor. So `FunctionalTensor` can temporarily live in a state where its data is invalid, until the `x.copy_(y)` "updates" its data with the new tensor.

So effectively, all that this rule does is:

(1) it stores metadata on the storage, indicating that the tensor was resized, as well as the updated storage size. We need this info in AOTAutograd, so it knows whether to emit a mutable resize_() op in the graph epilogue

(2) There is also a corner case: if we are resizing down to zero, but our tensor **previously** had a zero-size storage, then we update `value_` to point to the original value of the tensor. The reason this seems safe is that if we have a zero-storage-sized tensor `x`, and we resize it up, use it in some compute, resize it back down to zero, and use it somewhere, we would want the functional version of this code to use the original `x` after the second resize. For FSDP, this is important because we end up saving parameters (graph inputs) for backward, and we want to make sure that the thing we save (and the output to the forward graph) is the original, zero-storage-sized parameter, and not the "version 2" of the parameter after the first resize_()

I think a good order to look at changes in this PR would be:

(1) `test_aotdispatch.py` shows the 3 main cases I focused on as well as the expected functionalized graphs

(2) In `FunctionalStorageImpl.h/cpp`, I had to add a notion of "original base", and "original/curr_size". The first is so I can re-use the zero-size tensor after multiple resizes, and the second is so I can tell in AOTAutograd whether any resizes canceled each other out into a no-op

(3) `FunctionalTensorWrapper.h/cpp` has the new resize functionalization rule + some extra utils

(4) `_functorch/_autograd`: the main changes in this folder were around adding the logic at trace-time to detect when we need to put a resize_() in the graph. I also have some assertions to check that any inputs that experience storage resizing will **always be in the graph** and not the opaque epilogue, and I also limited the resize_() mutation case so that you can only ever start with zero storage, or end with zero storage (you can't do e.g. `torch.ones(2).storage().resize_(3)`), and banned it on tensor subclasses

(5) `fake_tensor.py`/`meta_utils.py`: we now need to be able to fakeify tensors with zero storage, so I added a quick version of it in meta_utils.py. This also.. has ramifications for fake tensor caching that I need to fix (include the storage size on the cache key, maybe?)

------------------

This PR subsumes https://github.com/pytorch/pytorch/pull/120971.

This PR is enough to **almost** get a simple ppFSDP forward pass tracing with a functionalized resize_() properly. It also attempts to do the updated version from @jansel, where we don't have any notion of `resize_()` in the graph at all, post functionalization. It would probably be good to test it with @yf225 's FSDP changes, and see how many of the FX passes it allows us to remove. I think that in theory, it should allow us to remove all FX passes that affect the forward graph / partitioner, **except** the one that forces views to be recomputed in the backward (more details below).

There are a few things worth calling out:

(1) failed attempt at functionalizing `aten.copy_()`. I originally wanted to get a version that takes these operations:
```
param.storage().resize_(all_gather_size)
param.copy_(all_gather_buffer)
out = aten.matmul(param, param)
```
and functionalizes them into:
```
out = aten.matmul(all_gather_buffer, all_gather_buffer)
```

This would involve getting functionalization to turn `x.copy_(y)` into a giant no-op that just returns `y`. Unfortunately, we can't actually do this in a reasonable way within functionalization (instead, there's a functional `aten.copy` in the graph - see the test case graph expecttest for details). Why? In order for that transformation to be safe, `x` and `y` need to have the same metadata. However, it's possible for `x` and `y` to be subclasses of different types. This is not something we can easily tell from within functionalization, and would be a layering violation. So for now I'm leaving it to downstream code to optimize away the `aten.copy` (this is already the case today, so I think inductor can handle this)

(2) The forward doesn't **actually** run successfully in this PR (see the `assertRaisesRegex` in the test). Why?

The final forward graph looks like this:
```
def forward(self, primals_1, primals_2):
    _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]);  primals_2 = None
    getitem = _foreach_copy[0];  _foreach_copy = None
    mm = torch.ops.aten.mm.default(getitem, getitem);  getitem = None
    t_1 = torch.ops.aten.t.default(primals_1);  primals_1 = None
    return [mm, t_1]
```

Where `primals_1` starts out as a secretly-zero-storage-size parameter, and gets resized up and back down within the forward (these are functionalized away).

Importantly, the matmul happens on the result of the `foreach_copy`, **but** the activation that we save for backward (`t_1`) is the result of transposing the **original parameter** (the zero-storage-size param). This is exactly the optimization in fsdp that allows us to have good peak memory usage.

The problem is that the min-cut partitioner decides to save `t_1` for backward. Running this code in eager breaks, because the kernel for `aten.permute(x)` is not happy when `x` has secretly-zero-sized-storage.

The real problem here is that in eager mode the `permute` kernel runs during the backward, after backward hooks have properly resized the saved activation. Here, we are running the transpose in the forward.

One option would be to turn off the checks in our view kernels and allow them to work on zero-storage-sized tensors, which feels pretty bad. Another option is to tweak the partitioner (or use one of Will's FX passes) to force the partitioner to not save views for backward, and allow the views to be recomputed in the backward. This seems kind of silly, but is also probably harmless.

(3) The backward is still broken. To be fair, this issue is pretty separable from "functionalizing storage resize calls", and can be fixed later (either by a real fix to our tracing infra, or via another hacky FX pass). More description of this problem is described at issue (8) of my PR description in https://github.com/pytorch/pytorch/pull/120971

(4) I only added support for "full graph" resizing: basically, the limited case where a param starts with zero storage size, and gets resized up and back down. I think we can add support for the graph break case, but I think we can keep that add-on separate from this PR unless we need it immediately. I also added asserts so we should fail loudly when we hit this case

(5) I have a change to FakeTensor creation when inputs have zero storage size that.. is probably ok. But I also removed FakeTensor caching on view ops, which I probably need to fix before I can land this PR

(6) I added a notion of "original_base" to `FunctionalStorageImpl`. More details are in the comments, but my rationale for this was that we basically need it to ensure that autograd saves the **original**, zero-storage-sized param for backward, after resizing up and back down

(7) I had to update our eager kernels for `aten.copy` and `aten._foreach_copy`, to handle the case where the `self` argument has secretly-zero-storage. Inductor can probably generate correct code for this case, but we need these ops to work properly in this situation for the `aot_eager` backend to do the right thing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122434
Approved by: https://github.com/jansel
2024-05-10 18:09:10 +00:00
Edward Z. Yang
fcbf2b61e6 Memoize local_scalar_dense calls, refactor all memos (#125623)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125623
Approved by: https://github.com/eellison
2024-05-10 01:52:55 +00:00
Aaron Orenstein
e43d656921 FakeTensor speedup: minor cleanups (#124224)
A few cleanup tasks that didn't really fit into the other diffs in this stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124224
Approved by: https://github.com/oulgen
ghstack dependencies: #122911, #124223
2024-05-09 22:11:51 +00:00
Aaron Orenstein
a08be4b705 FakeTensor speedup: Split cache_key so we only validate once (#124223)
When dispatching a fake tensor op we cache the result with `(op, args)` as the key. There are some args (such as one with a dynamic output shape) where the output can't be cached. Instead of validating the args every time we compute the cache key, we now only validate the args when we first see a new cache key.
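
A toy sketch of the validate-once idea, with hypothetical `validate`/`compute` helpers standing in for the real FakeTensor logic:

```
_cache = {}
_validated_keys = set()

def cached_dispatch(op, args, validate, compute):
    key = (op, args)
    if key not in _validated_keys:
        validate(op, args)  # e.g. reject caching for dynamic output shapes
        _validated_keys.add(key)
    if key not in _cache:
        _cache[key] = compute(op, args)
    return _cache[key]
```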

18.3% FakeTensor perf win on the microbenchmark (21.7% cumulative)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124223
Approved by: https://github.com/oulgen, https://github.com/masnesral
ghstack dependencies: #122911
2024-05-09 22:11:51 +00:00
Aaron Orenstein
6a8b1da18d FakeTensor speedup: Delay formatting stack trace until it's actually asked for. (#122911)
When constructing a `FakeTensorMode`, instead of immediately formatting a full stack trace, grab the traceback and only format it on demand.

4.2% FakeTensor perf win on the microbenchmark.

```
import time
import torch
import torch._dynamo as dynamo
from torch._subclasses.fake_tensor import FakeTensorMode
import numpy as np

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    b = b * -1
    return x * b

def run_test1():
    dynamo.reset()
    j = [1, 2, 3]
    toy_example(torch.randn(j), torch.randn(j))

def run_test2():
    dynamo.reset()
    j = [1, 2, 3]
    with FakeTensorMode():
        toy_example(torch.randn(j), torch.randn(j))

ITERATIONS = 500000
FORMAT_STRING = "{name:12}: TOT: {tot:10.3f}, AVG: {avg:10.3f}, MIN: {min:10.3f}, P50: {p50:10.3f}, P90: {p90:10.3f}, P99: {p99:10.3f}"

def run_tests(name, step):
    step()
    timings = []
    start = time.time()
    for i in range(ITERATIONS):
        a = time.perf_counter_ns()
        step()
        b = time.perf_counter_ns()
        timings.append(b - a)
    end = time.time()
    fmt = {
        "best": min(timings),
        "tot": end - start,
        "avg": np.average(timings),
        "min": min(timings),
        "p50": np.percentile(timings, 50),
        "p90": np.percentile(timings, 90),
        "p99": np.percentile(timings, 99)
    }
    print(FORMAT_STRING.format(name=name, **fmt))
    return fmt

ts = run_tests("tensor", run_test1)
fs = run_tests("fake tensor", run_test2)
ratio = {k: a / b for ((k, a), (_, b)) in zip(fs.items(), ts.items())}
print(FORMAT_STRING.format(name="ratio", **ratio))
```
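
A rough sketch of the deferral itself, assuming a hypothetical `DeferredStackTrace` helper (the actual change uses PyTorch's internal traceback capture):

```
import sys
import traceback

class DeferredStackTrace:
    def __init__(self):
        # Grabbing the caller's frame is cheap; no string formatting yet.
        self._frame = sys._getframe(1)
        self._formatted = None

    def format(self):
        # Format lazily, only if someone actually asks for the trace.
        if self._formatted is None:
            summary = traceback.StackSummary.extract(
                traceback.walk_stack(self._frame)
            )
            self._formatted = "".join(summary.format())
        return self._formatted
```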

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122911
Approved by: https://github.com/oulgen, https://github.com/eellison
2024-05-09 22:11:51 +00:00
Aaron Gokaslan
8cad88e1f3 [BE]: Improve exception typing. Remove NOQAs (#125535)
Improve some exception typing

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125535
Approved by: https://github.com/albanD
2024-05-08 14:07:13 +00:00
PyTorch MergeBot
7ffa5558ee Revert "[FX] Update type hints in torch.fx._compatibility.py (#125469)"
This reverts commit 235b4d6ec2.

Reverted https://github.com/pytorch/pytorch/pull/125469 on behalf of https://github.com/izaitsevfb due to breaks pyre in dependent projects (internal: see D56986361) ([comment](https://github.com/pytorch/pytorch/pull/125469#issuecomment-2096665396))
2024-05-06 18:36:43 +00:00
Aaron Gokaslan
1dd42e42c4 [BE]: Try TCH autofixes on torch/ (#125536)
Tries TCH autofixes to see what breaks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125536
Approved by: https://github.com/ezyang
2024-05-05 23:13:59 +00:00
Xuehai Pan
235b4d6ec2 [FX] Update type hints in torch.fx._compatibility.py (#125469)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125469
Approved by: https://github.com/Skylion007
ghstack dependencies: #125468
2024-05-05 19:30:22 +00:00
Edward Z. Yang
e93b57a570 Add propagate_real_tensors mode for unbacked (#125115)
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or work around it to see the next one.

This PR adds a new mode `torch._functorch.config.fake_tensor_propagate_real_tensors` which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are.
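
As a sketch, enabling the mode and hitting a data-dependent guard might look like this (the `capture_scalar_outputs` toggle and toy function are illustrative, not from the PR):

```
import torch
import torch._dynamo.config as dynamo_config
import torch._functorch.config as functorch_config

dynamo_config.capture_scalar_outputs = True
functorch_config.fake_tensor_propagate_real_tensors = True

@torch.compile
def f(x):
    u = x.item()           # data-dependent scalar
    return torch.zeros(u)  # guard on u is resolved using the real value

f(torch.tensor(5))
```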

I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensors and some other missing Dynamo features). You get a list of things that would have errored like this:

```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```

Potential later follow ups:

* Improve the warning messages (in particular, should provide user frames)
* GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, as if your GC were broken and every intermediate tensor were kept live

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125115
Approved by: https://github.com/IvanKobzarev
2024-05-02 15:28:26 +00:00
Avik Chaudhuri
746da8755c switch tests from constrain_as* to torch._check* (#125253)
To fix data-dependent errors we want to recommend that people use `torch._check*` APIs. The `constrain_as*` APIs should be fully subsumed by them, and in the future we should kill them entirely.
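
For illustration, a minimal sketch of the recommended replacements (toy function; the mapping from `constrain_as*` calls is approximate):

```
import torch

def f(x):
    u = x.item()
    torch._check_is_size(u)  # roughly replaces constrain_as_size(u)
    torch._check(u < 100)    # roughly replaces a constrain_as_value upper bound
    return torch.zeros(u)
```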

Differential Revision: D56774333

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125253
Approved by: https://github.com/ezyang
2024-05-01 21:01:27 +00:00
Edward Z. Yang
c511aed27f [Meta Tensor] fix meta inplace set storage (#123880)
Fixes #123879

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123880
Approved by: https://github.com/ezyang
2024-05-01 06:53:49 +00:00
PyTorch MergeBot
ae13c7e593 Revert "[Meta Tensor] fix meta inplace set storage (#123880)"
This reverts commit cccae93551.

Reverted https://github.com/pytorch/pytorch/pull/123880 on behalf of https://github.com/izaitsevfb due to breaks cpu_inductor_torchbench (detectron2_fasterrcnn) ([comment](https://github.com/pytorch/pytorch/pull/123880#issuecomment-2083366385))
2024-04-29 18:19:42 +00:00
Yuanjing Shi
cccae93551 [Meta Tensor] fix meta inplace set storage (#123880)
Fixes #123879

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123880
Approved by: https://github.com/ezyang
2024-04-28 17:01:12 +00:00
Aaron Orenstein
609c958281 Fix mypy issues in fake_tensor.py (#124428)
fake_tensor.py had mypy errors ignored. That seems less than desirable.

Also added `SafePyObjectT<T>`, which is a tagged wrapper around a `SafePyObject` but provides static type checking (with no other guarantees).

Used `SafePyObjectT<TorchDispatchModeKey>` on some of the TorchDispatchModeTLS API to ensure that we don't accidentally inject a different type than expected into the stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124428
Approved by: https://github.com/malfet
2024-04-26 15:35:53 +00:00
PyTorch MergeBot
f131c2c199 Revert "Fix mypy issues in fake_tensor.py (#124428)"
This reverts commit 25c0d3f3f0.

Reverted https://github.com/pytorch/pytorch/pull/124428 on behalf of https://github.com/jeanschmidt due to Unfortunately, I needed to revert #123735 and this one depends on it. So please check if there are no merge conflicts or breakages and feel free to merge this PR again ([comment](https://github.com/pytorch/pytorch/pull/124428#issuecomment-2078699836))
2024-04-26 06:15:17 +00:00
Aaron Orenstein
25c0d3f3f0 Fix mypy issues in fake_tensor.py (#124428)
fake_tensor.py had mypy errors ignored. That seems less than desirable.

Also added `SafePyObjectT<T>`, which is a tagged wrapper around a `SafePyObject` but provides static type checking (with no other guarantees).

Used `SafePyObjectT<TorchDispatchModeKey>` on some of the TorchDispatchModeTLS API to ensure that we don't accidentally inject a different type than expected into the stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124428
Approved by: https://github.com/malfet
2024-04-25 14:07:53 +00:00
Edward Z. Yang
13ab24f192 Reimplement unbacked symbol bindings in Inductor (#124394)
This PR has a lot of "draw the rest of the fucking owl" energy. Here's how to break it down.

1. **torch/_inductor/graph.py** - We start by tightening unbacked symbol invariants. Specifically, as we lower FX nodes, we check whether or not every unbacked_binding recorded on the FX node meta actually ends up getting bound (according to get_unbacked_symbol_defs) in all the buffers generated by the lowering. Hopefully this invariant is self-evident. This leads to a lot of failures.
2. **torch/_inductor/ir.py** - Problem 1: There is softness in how Inductor computes defs of unbacked symbols in IR nodes. Previously, we tried to infer it by looking at the output sizes/strides/etc. and seeing if new unbacked symbols popped up that we hadn't seen in the inputs. I don't know exactly what was buggy about the old code, but sometimes we would fail to notice an unbacked symbol had been bound, or rebind an unbacked symbol multiple times. Fortunately, thanks to the earlier PRs in our stack, we now have a nice list of unbacked symbol bindings from FX, so we now just store it directly on ExternKernel and use it directly to report defs. This has to be done twice: once for FallbackKernel (e.g., nonzero) and once for DynamicScalar (e.g., item) (see also **torch/_inductor/lowering.py**, **torch/_inductor/codegen/wrapper.py** and **torch/_inductor/codegen/cpp_wrapper_cpu.py** for the lowering and codegen changes for item)
    * **process_kernel** - Sidequest! It turns out that Inductor lowering can reallocate unbacked symbols. This happens specifically when we repropagate fake tensors through the operator in `process_kernel`. This repropagation process is necessary because Inductor may have changed the strides of input tensors, and it must now recompute the strides so that it can continue to appropriately plan the rest of the lowering process. This is fine: we just make sure we do the rebind unbacked + compute_unbacked_bindings dance we've been doing previously in the PR stack. But instead of putting unbacked_bindings on a new FX node, they go straight into our unbacked_bindings on the Inductor IR node.
    * **codegen_unbacked_symbol_defs** - Sidequest! FallbackKernel lowering is done in two steps. First, you emit the FallbackKernel buffer. Then, you emit MultiOutput buffers which actually give access to the individual outputs of FallbackKernel, which may have been multi-output. There is a design decision here: does the FallbackKernel bind the unbacked symbols, or the MultiOutput buffer? Historically, we put the binding on MultiOutput buffer, because it's more convenient: the FallbackKernel buffer is fake, in fact, it doesn't even get a name in C++ codegen. But it's kind of inconsistent with the keypath model that we've been tracking unbacked bindings with: if you have a multi-output node, you'd expect a keypath like `[0].size()[0]` representing the first output's first dimension size. That suggests that it's the FallbackKernel that should define the things. So that was my first implementation. Unfortunately, the C++ codegen is too cursed and I could not understand how to make it work in that case. So now we just unsoundly assume you cannot have multi-output data dependent output, and do the codegen in MultiOutput. There are some comments explaining exactly what we are improperly assuming.
3. **_rename_unbacked_to** in **torch/fx/experimental/symbolic_shapes.py** - Previously, when we renamed unbacked symbols, we clobbered any facts we previously knew about them. So for example, if we had a replacement `u0 -> s0` but then we renamed u0 to u1, we would now set up the replacement `u0 -> u1`, clobbering the old replacement. This apparently didn't matter in earlier PRs in the stack, but with Inductor now on the ball, there were some tests that indicated this was a problem. The solution is easy: if u0 had a preexisting replacement, reapply it to u1 (a toy sketch of this fix appears after this list). However...
    * **torch/_functorch/_aot_autograd/collect_metadata_analysis.py** - When we run forward analysis, this triggers fake tensor repropagation and fresh allocations. Previously, we just cleared out the pending symbols when we finished the analysis. But with the change above, this would also migrate replacements to the new symbols... which are now dead. So now we explicitly suppress generation of these symbols with `ignore_fresh_unbacked_symbols` so that no rebinding happens at all.
    * **torch/_dynamo/eval_frame.py** - same deal; I just searched for all sites we called clear() on pending
4. The last step is fixing the long tail of extra problems that show up, now that unbacked_bindings are load bearing into Inductor
    * **torch/_dynamo/eval_frame.py** - Some of the exports are making copies of nodes without repropagating fake tensors, so in this case, it is important to also copy the `unbacked_bindings` (apparently this didn't matter before without the Inductor changes)
    * **torch/_export/pass_base.py** - I discover that this is doing fake tensor repropagation via a test suite failure. Do the same playbook as AOTAutograd: PropagateUnbackedSymInts too!  Actually, they also have implemented their own tracer as well, so do the same playbook as proxy_tensor: record unbacked_bindings on the newly traced nodes. UGH code duplication.
    * **torch/_subclasses/fake_tensor.py**, **torch/_subclasses/fake_impls.py** (with call site updates at  **torch/_functorch/_aot_autograd/traced_function_transforms.py** and **torch/fx/passes/fake_tensor_prop.py**) - What's this new epoch thing? I noticed that sometimes I would be retracing, call nonzero() on a fake tensor, and not allocate a new unbacked symbol. This is actually bad, because if I don't get a new unbacked symbol, I don't know there's a binding site, and `unbacked_bindings` is now missing a binding. The reason for this is memoization: if I reuse the exact same fake tensor on my retrace, it will already have an unbacked symint memoized on it and we will short circuit allocation. Well, that's no good. So I associate the memos with a fake tensor epoch, and every time you start a new fake tensor propagation from scratch, you bump the epoch so that I clear all the memos.
    * **torch/_inductor/scheduler.py** - I notice in unit tests that V.current_node is not always set when we call process_kernel. So I save it into the IR node and restore it when we are running `get_estimated_runtime`.
    * **torch/fx/experimental/symbolic_shapes.py** - A few things
      * **rebind_unbacked** (re **_tensor_version**). Ordinarily, when you have an unbacked SymInt, you persistently have it all the way to the end of the program. `_tensor_version` violates this: this generates an unbacked SymInt (for reasons I don't quite understand?) and then gets rid of it later. This triggered an assert violation. I think this op is kind of misusing unbacked SymInt, but I didn't know how to refactor it, so it gets a special case.
      * **rebind_unbacked** (re **Simplify SymBool binding**). Ugh, SymBool, what a pain in the butt. I have an assert that you can only rebind unbacked symbol to another unbacked symbol. This assert fails when a boolean is involved, because the result of running keypath on the result is not `u1`, it's `sympy.Piecewise(... sympy.Eq(u1, 1) ...)`. This is actually just `u1`, but Sympy doesn't know it because it doesn't know that `u1` value range is `[0, 1]`. So we manually implement the simplification needed to get the assert to pass.
      * **compute_unbacked_bindings** (re **This is pretty fragile**). There is a really funny disaster involving memoization and Inductor process kernel. Ordinarily when I retrace, if there was a memo hit in the old trace, there will be a memo hit in the new trace. However, Inductor process kernel breaks this, because it recreates fake tensor inputs to the operator call from scratch (since they might have different strides), and obviously these tensor inputs don't have the memo from the old one. I tried a little bit to try to manually transplant the memo to the new fake tensor but it seemed hopeless, so I just let the fresh symbol ride, allocating a new unbacked symbol. However, in one of our tests, we rely on knowing that the first nonzero call is equal to the second (memoized) nonzero call. The equality test looked pretty easy to discharge, so I just went ahead and added a deferred runtime assert to this effect and it worked.
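
A toy sketch of the replacement-preserving rename from point (3), using a plain dict in place of the real ShapeEnv state:

```
def rename_unbacked_to(replacements, old, new):
    prior = replacements.get(old)  # e.g. u0 -> s0
    replacements[old] = new        # record the rename u0 -> u1
    if prior is not None and prior != new:
        replacements[new] = prior  # reapply the old fact: u1 -> s0
```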

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124394
Approved by: https://github.com/jansel
ghstack dependencies: #124310, #124314, #124316
2024-04-25 02:08:59 +00:00
Edward Z. Yang
a22847a9cb We should not be in kernel invocation before we restore fake mode (#124762)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124762
Approved by: https://github.com/eellison
ghstack dependencies: #124760
2024-04-24 20:32:59 +00:00
Edward Z. Yang
0d58aeb73a Handle size/etc accessors in FakeTensor, support accessing symbolic types from toInt/etc in IValue (#124760)
Fixes https://github.com/pytorch/pytorch/issues/122772

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124760
Approved by: https://github.com/albanD, https://github.com/eellison
2024-04-24 20:32:59 +00:00
Aaron Gokaslan
c5fafe9f48 [BE]: TRY002 - Ban raising vanilla exceptions (#124570)
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime errors, value errors, type errors, or some other more specific error type. There are hundreds of instances of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get committers to rethink what exception type they should raise when they submit a PR.

I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed over time and our exception typing can be improved.
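
For illustration, the kind of pattern TRY002 flags, and its fix (toy example):

```
def load(path):
    if not path:
        raise Exception("empty path")  # TRY002: raw Exception

def load_fixed(path):
    if not path:
        raise ValueError("empty path")  # specific exception type
```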

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
2024-04-21 22:26:40 +00:00
rzou
889e3eeed3 Avoid cuda init to FakeTensorMode (#124413)
Also partially fixes #122109

This PR:
- We add a C++ flag (only_lift_cpu_tensors) to toggle the
  torch.tensor(1, device='cuda') ctor strategy.
  When false (default), it does the current PyTorch behavior
  of unconditionally constructing a concrete CUDA tensor then calling
  lift_fresh on it. When true, we instead construct a concrete CPU
  tensor, call lift_fresh, and then call Tensor.to(device) (under any ambient
  modes).
- FakeTensorMode flips this flag depending on whether CUDA is available. We
  don't unconditionally set the flag to True because that is likely
  BC-breaking.
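
A minimal sketch of the call pattern this changes (illustrative; the observable difference is whether CUDA gets initialized during fake tensor tracing):

```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # With only_lift_cpu_tensors active, this builds a concrete CPU tensor,
    # lifts it, then moves it to "cuda" under the ambient fake mode, so no
    # real CUDA tensor is ever constructed.
    t = torch.tensor(1, device="cuda")
```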

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124413
Approved by: https://github.com/eellison
2024-04-19 02:39:35 +00:00
Tugsbayasgalan Manlaibaatar
d23bf9cef0 Add fake impl for aten.unique2 (#124306)
Reapply of: https://github.com/pytorch/pytorch/pull/121571
Differential Revision: [D56258431](https://our.internmc.facebook.com/intern/diff/D56258431)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124306
Approved by: https://github.com/gmagogsfm
2024-04-17 22:55:27 +00:00
Aaron Orenstein
2bcc83dfbd Preserve dispatch state across function tracing (#122073)
If we throw an exception in the "wrong" place we can end up with the dispatcher in a weird state, which can cause all future dispatching to fail. Preserve and restore it as part of `preserve_global_state` so we know it's sane after that.

Also fake_tensor's in_kernel_invocation_manager() was leaving a bit set in the dispatcher (DispatchKey.Dense) which affected follow-on code.  Fixed that to reset after as well.

Repro:

before:
```
$ rm test/dynamo_skips/TestSparseCPU.test_to_dense_with_gradcheck_sparse_cpu_complex64
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_to_dense_with_gradcheck_sparse_cpu_complex64'
======== 1 passed, 6173 deselected in 5.21s =============
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_torch_inference_mode_ctx or test_to_dense_with_gradcheck_sparse_cpu_complex64'
========= 1 skipped, 6172 deselected, 1 error in 5.29s =========
```
(note that test_to_dense_with_gradcheck_sparse_cpu_complex64 passes on its own but failed when including the skipped test_export.py tests)
after:
```
$ rm test/dynamo_skips/TestSparseCPU.test_to_dense_with_gradcheck_sparse_cpu_complex64
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_to_dense_with_gradcheck_sparse_cpu_complex64'
===================== 1 passed, 6173 deselected in 5.42s =====================
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_torch_inference_mode_ctx or test_to_dense_with_gradcheck_sparse_cpu_complex64'
===================== 1 passed, 1 skipped, 6172 deselected in 7.30s ======================
```
(note that test_to_dense_with_gradcheck_sparse_cpu_complex64 passes in both runs)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122073
Approved by: https://github.com/zou3519
2024-04-10 18:57:01 +00:00
rzou
81e7a7c955 Add mutated_args field to custom_op (#123129)
If provided, we:
- autogenerate an ADInplaceOrView implementation
- assume that no mutated inputs are returned as outputs. There are
  already aliasing runtime checks that check this.
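
A hedged sketch using the `torch.library.custom_op` decorator; the public parameter there is `mutates_args`, which corresponds to the mutated_args field this PR adds (treat the exact spelling and signature as assumptions):

```
import torch

@torch.library.custom_op("mylib::scale_", mutates_args=("x",))
def scale_(x: torch.Tensor, factor: float) -> None:
    # declared as mutating "x"; must not return the mutated input
    x.mul_(factor)
```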

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123129
Approved by: https://github.com/albanD
ghstack dependencies: #123108, #123109, #123110
2024-04-05 22:03:51 +00:00
rzou
d486cb7c1b Deprecate calling FakeTensor.data_ptr in eager-mode (#123292)
Today, we error out on FakeTensor.data_ptr under torch.compile. This PR
moves to error out on FakeTensor.data_ptr under eager mode to avoid
diverging behavior.

We do this by adding another bit onto FakeTensor that we'll remove after
the deprecation cycle.

Test Plan:
- tested locally
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123292
Approved by: https://github.com/eellison
ghstack dependencies: #123261, #123282, #123291
2024-04-04 20:35:24 +00:00
rzou
fd60752786 Turn _allow_unsafe_data_ptr_access into a config option (#123291)
We're not planning on having this flag around for very long (see
deprecation in next PR), so it's better as a config option.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123291
Approved by: https://github.com/eellison
ghstack dependencies: #123261, #123282
2024-04-04 20:35:24 +00:00
ydwu4
c77352b5cc Add torch._library.register_fake_class to fakify torchBind class (#122622)
This PR only adds abstract class registration logic without touching existing tests, so they still trace with real script objects. The added tests are only for the registration APIs and the error messages.

Our design is that the abstract implementation should be in Python. This is much better in terms of usability. But this also has implications for custom ops that take script objects as input, which are detailed later in this stack.
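
A hedged sketch of the registration API named in the title; the namespace string and the `__obj_unflatten__` protocol details are assumptions based on the commit description:

```
import torch

@torch._library.register_fake_class("mylib::Accumulator")
class FakeAccumulator:
    def __init__(self, total):
        self.total = total

    @classmethod
    def __obj_unflatten__(cls, flattened_obj):
        # rebuild the fake object from the real script object's flattened state
        return cls(**dict(flattened_obj))
```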

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122622
Approved by: https://github.com/zou3519
ghstack dependencies: #122619, #122620, #122621
2024-04-02 23:52:17 +00:00
rzou
c81c9ba472 Disallow {FakeTensor,FunctionalTensor}.data_ptr (#122514)
This PR:
- disallows FakeTensor.data_ptr when it is called inside PT2 or fx tracing.
- disallows FunctionalTensor.data_ptr (python FunctionalTensor is only used in
  PT2)

The motivation behind this is that the leading cause of segfaults when
using custom ops with PT2 is calling .data_ptr on FunctionalTensor or
FakeTensor.

This change is BC-breaking. If your code broke as a result of this, it's
because there was a bug in it (these .data_ptr should never be
accessed!). You can either fix the bug (recommended) or get the previous
behavior back with:
```
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor

data_ptr = 0 if isinstance(tensor, (FakeTensor, FunctionalTensor)) else tensor.data_ptr()
```

Test Plan:
- existing tests

Differential Revision: [D55366199](https://our.internmc.facebook.com/intern/diff/D55366199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122514
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/yifuwang, https://github.com/kurtamohler
2024-03-26 23:55:42 +00:00
Edward Z. Yang
85845a29db Refactor ShapeEnvSettings so it's directly on ShapeEnv (#122310)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122310
Approved by: https://github.com/masnesral, https://github.com/lezcano
2024-03-26 14:16:33 +00:00
Edward Z. Yang
268b0cc714 Do not run CUDA lazy init if it is triggered with fake mode on. (#122636)
Partially fixes https://github.com/pytorch/pytorch/issues/122109

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122636
Approved by: https://github.com/zou3519
2024-03-26 05:43:59 +00:00
Edward Z. Yang
47a9725de9 Implement prefer_deferred_runtime_asserts_over_guards (#122090)
Fixes https://github.com/pytorch/pytorch/issues/121749

As promised, it is pretty easy.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122090
Approved by: https://github.com/lezcano
2024-03-25 16:31:16 +00:00
Edward Z. Yang
1e404c9b12 Remove redundant query to tensor_to_context (#122278)
from_real_tensor will query it again, so this query is strictly
dominated.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122278
Approved by: https://github.com/eellison
ghstack dependencies: #122044, #122270, #122271
2024-03-25 13:16:21 +00:00
Edward Z. Yang
49b81af45f Delete dead memoized_only kwarg in FakeTensor (#122271)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122271
Approved by: https://github.com/eellison
ghstack dependencies: #122044, #122270
2024-03-25 13:16:21 +00:00
Edward Z. Yang
f32ce4e28e Delete FakeTensorConverter.__call__ in favor of from_real_tensor (#122270)
It's annoying grepping for `__call__` call-sites so they're all explicit now. I'd do this to MetaConverter too but that one is way more public, a lot more sites.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122270
Approved by: https://github.com/eellison
ghstack dependencies: #122044
2024-03-25 13:16:13 +00:00
Edward Z. Yang
5891c5b3a6 Factor meta conversion through serializable MetaTensorDesc (#122044)
Fixes https://github.com/pytorch/pytorch/issues/121085

This PR is pretty involved so pay attention to this description.  At a high
level, the refactor is intended to be mechanical: anywhere in
MetaConverter where previously we took a Tensor as argument, we now take
a MetaTensorDesc, which contains all of the information that we would
have queried off of the Tensor, but placed into a separate data
structure which we can serialize or use to recreate a fake tensor in
a separate fake tensor mode in exact fidelity to the original (a toy
sketch of such a record appears after the list below).

However, this transformation is not always entirely mechanical.  Here
is what you need to pay attention to:

- The memo table from real Tensor -> meta/fake Tensor is now broken
  into two memo tables: real Tensor -> stable int id -> meta/fake
  Tensor.  The stable int id is needed so that when we do serialization,
  we know when tensors/storages alias each other and can ensure we preserve
  this aliasing upon deserialization.

  The way I have implemented changes the weak reference behavior.
  Previously, when either the real Tensor OR the meta/fake Tensor went
  dead, we would remove the entry from the memo table.  Now, this only
  removes entries from one of the two memo tables.  This semantically
  makes sense, because the user may have held on to the stable int id
  out of band, and may expect a real Tensor to continue to be numbered
  consistently / expect to be able to lookup a meta/fake tensor from
  this id.  If this is unacceptable, it may be possible to rejigger
  the memo tables so that we have real Tensor -> stable int id
  and real Tensor -> meta/fake Tensor, but TBH I find the new
  implementation a lot simpler, and arranging the memo tables in this
  way means that I have to muck around with the real tensor to save
  to the memo table; in the current implementation, I never pass the
  Tensor to meta_tensor function AT ALL, which means it is impossible
  to accidentally depend on it.

- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
  to be careful not to poke fields when they are not valid.  Previously,
  preconditions were implicitly checked via the conditional structure
  ("is this sparse? is this nested?") that is tested before we start
  reading attributes.  This structure has to be replicated in
  describe_tensor, and I have almost assuredly gotten it wrong on my
  first try (I'll be grinding through it on CI; a careful audit will
  help too, by auditing that I've tested all the same conditionals that
  the original access was guarded by.)

- I originally submitted https://github.com/pytorch/pytorch/pull/121821
  for the symbolic shapes change, but it turned out the way I did it
  there didn't actually work so well for this PR.  I ended up just
  inlining the symbolic shapes allocation logic into MetaConverter
  (look for calls to maybe_specialize_sym_int_with_hint), maybe there
  is a better way to structure it, but what I really want is to
  just read sizes/strides/offset directly off of MetaTensorDesc; I
  don't want another intermediate data structure.

- Some fields aren't serializable. These are documented as "NOT
  serializable".  ctx/type should morally be serializable and I just
  need to setup a contract with subclasses to let them be serialized.
  The fake_mode is used solely to test if we are refakefying with
  a pre-existing ShapeEnv and we want to reuse the SymInt
  directly--serializing this case is hopeless but I am kind of hoping
  after this refactor we do not need this at all.  view_func is not
  serializable because it's a bound C implemented method.  Joel has
  promised me that this is not too difficult to actually expose as a
  true data structure, but this is the edgiest of edge cases and there
  is no reason to deal with it right now.
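
A toy sketch of the MetaTensorDesc idea: capture what meta conversion needs from a real tensor into a plain, serializable record (fields here are illustrative, not the full set the PR records):

```
from dataclasses import dataclass
from typing import Tuple

import torch

@dataclass
class MetaTensorDesc:
    id: int  # stable int id, used for aliasing bookkeeping
    size: Tuple[int, ...]
    stride: Tuple[int, ...]
    storage_offset: int
    dtype: torch.dtype
    requires_grad: bool

def describe_tensor(t: torch.Tensor, stable_id: int) -> MetaTensorDesc:
    return MetaTensorDesc(
        id=stable_id,
        size=tuple(t.size()),
        stride=tuple(t.stride()),
        storage_offset=t.storage_offset(),
        dtype=t.dtype,
        requires_grad=t.requires_grad,
    )
```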

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122044
Approved by: https://github.com/eellison
2024-03-25 06:21:17 +00:00
PyTorch MergeBot
f65373e278 Revert "Factor meta conversion through serializable MetaTensorDesc (#122044)"
This reverts commit e2d89e9704.

Reverted https://github.com/pytorch/pytorch/pull/122044 on behalf of https://github.com/jeanschmidt due to Seems that some landrace caused this PR to break lint ([comment](https://github.com/pytorch/pytorch/pull/122044#issuecomment-2015025490))
2024-03-22 12:46:21 +00:00
Edward Z. Yang
e2d89e9704 Factor meta conversion through serializable MetaTensorDesc (#122044)
Fixes https://github.com/pytorch/pytorch/issues/121085

This PR is pretty involved so pay attention to this description.  At a high
level, the refactor is intended to be mechanical: anywhere in
MetaConverter where previously we took a Tensor as argument, we now take
a MetaTensorDesc, which contains all of the information that we would
have queried off of the Tensor, but placed into a separate data
structure which we can serialize or use to recreate a fake tensor in
a separate fake tensor mode in exact fidelity to the original.

However, this transformation is not always entirely mechanical.  Here
is what you need to pay attention to:

- The memo table from real Tensor -> meta/fake Tensor is now broken
  into two memo tables: real Tensor -> stable int id -> meta/fake
  Tensor.  The stable int id is needed so that when we do serialization,
  we know when tensors/storages alias each other and can ensure we preserve
  this aliasing upon deserialization.

  The way I have implemented changes the weak reference behavior.
  Previously, when either the real Tensor OR the meta/fake Tensor went
  dead, we would remove the entry from the memo table.  Now, this only
  removes entries from one of the two memo tables.  This semantically
  makes sense, because the user may have held on to the stable int id
  out of band, and may expect a real Tensor to continue to be numbered
  consistently / expect to be able to lookup a meta/fake tensor from
  this id.  If this is unacceptable, it may be possible to rejigger
  the memo tables so that we have real Tensor -> stable int id
  and real Tensor -> meta/fake Tensor, but TBH I find the new
  implementation a lot simpler, and arranging the memo tables in this
  way means that I have to muck around with the real tensor to save
  to the memo table; in the current implementation, I never pass the
  Tensor to meta_tensor function AT ALL, which means it is impossible
  to accidentally depend on it.

- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
  to be careful not to poke fields when they are not valid.  Previously,
  preconditions were implicitly checked via the conditional structure
  ("is this sparse? is this nested?") that is tested before we start
  reading attributes.  This structure has to be replicated in
  describe_tensor, and I have almost assuredly gotten it wrong on my
  first try (I'll be grinding through it on CI; a careful audit will
  help too, by auditing that I've tested all the same conditionals that
  the original access was guarded by.)

- I originally submitted https://github.com/pytorch/pytorch/pull/121821
  for the symbolic shapes change, but it turned out the way I did it
  there didn't actually work so well for this PR.  I ended up just
  inlining the symbolic shapes allocation logic into MetaConverter
  (look for calls to maybe_specialize_sym_int_with_hint), maybe there
  is a better way to structure it, but what I really want is to
  just read sizes/strides/offset directly off of MetaTensorDesc; I
  don't want another intermediate data structure.

- Some fields aren't serializable. These are documented as "NOT
  serializable".  ctx/type should morally be serializable and I just
  need to setup a contract with subclasses to let them be serialized.
  The fake_mode is used solely to test if we are refakefying with
  a pre-existing ShapeEnv and we want to reuse the SymInt
  directly--serializing this case is hopeless but I am kind of hoping
  after this refactor we do not need this at all.  view_func is not
  serializable because it's a bound C implemented method.  Joel has
  promised me that this is not too difficult to actually expose as a
  true data structure, but this is the edgiest of edge cases and there
  is no reason to deal with it right now.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122044
Approved by: https://github.com/eellison
ghstack dependencies: #122018
2024-03-22 03:56:34 +00:00
Wenting Wang
02bb2180f4 [torch export] replace traceback.extract_stack with CapturedTraceback.extract (#121449)
Summary:
with a simple bench in the TestDeserializer.test_basic function:
```
time_start = time.time()
for i in range(1000):
    self.check_graph(MyModule(), inputs)
warnings.warn(f"time_taken: {time.time() - time_start}")
```
and forcing FakeTensorConfig.debug to True, record_stack_traces to True, logging level to debug, it shows that the changed code is consistently around 20 seconds faster (~90s vs. originally ~110s)

Test Plan:
test passed, see summary

compared debug trace before and after:
- exactly the same for fake tensor and proxy callsite https://www.internalfb.com/intern/diffing/?paste_number=1189883685
- slightly different for the user frame in proxy node https://www.internalfb.com/intern/diffing/?paste_number=1189884347

Differential Revision: D54237017

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121449
Approved by: https://github.com/angelayi
2024-03-13 00:19:05 +00:00
albanD
6791b0c09e Change default torch_function behavior to be disabled when torch_dispatch is defined (take 2) (#120632)
This does not introduce a new test but is tested by checking that all the classes we already have still behave as before, now that they don't explicitly disable torch_function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
2024-03-09 01:08:37 +00:00
Aaron Orenstein
edd80f87b8 Prevent infinite recursion within Tensor.__repr__ (#120206)
`Tensor.__repr__` calls functions which can perform logging, which ends up logging `self` (via `__repr__`), causing an infinite loop. Instead of logging all the args in FakeTensor.dispatch, log the actual parameters (and use `id` to log the tensor itself).

The change to torch/testing/_internal/common_utils.py came up during testing: in some ways of running the test, `parts` was `('test', 'test_testing.py')`, so `i` was 0 and we were doing a join on `()`, which was causing an error.

Repro:
```
import torch
from torch.testing import make_tensor
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
t = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
t2 = FakeTensor.from_tensor(t, FakeTensorMode())
print(repr(t2))
```
and run with `TORCH_LOGS=+all`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120206
Approved by: https://github.com/yanboliang, https://github.com/pearu
2024-03-07 02:24:45 +00:00
Pearu Peterson
ce2903080c Add sparse compressed fake tensor support (#120920)
As in the title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120920
Approved by: https://github.com/ezyang
2024-03-04 14:38:45 +00:00
Pearu Peterson
b8e6ca6f76 Add sparse compressed meta tensor support (#120707)
As in the title.

Replaces https://github.com/pytorch/pytorch/pull/120498 and https://github.com/pytorch/pytorch/pull/120562

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120707
Approved by: https://github.com/ezyang
ghstack dependencies: #120703
2024-03-01 13:28:47 +00:00
James Wu
a911eb74ae [dynamo] Graph break when faking named tensors (#120779)
Fixes #120644
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120779
Approved by: https://github.com/zou3519
2024-02-29 18:22:15 +00:00
Guilherme Leobas
491c2b4665 Let torch dynamo inline torch.func.grad (#118407)
When dynamo sees torch.func.grad, it tries to inline all frames related
to it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118407
Approved by: https://github.com/zou3519
2024-02-28 20:05:00 +00:00
Isuru Fernando
b7df3bba62 add decomposition for frexp (#119217)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119217
Approved by: https://github.com/peterbell10
ghstack dependencies: #119284, #120027
2024-02-23 21:52:42 +00:00
Thiago Crepaldi
761fa5d6ec Add FakeTensor support to torch._utils._rebuild_tensor (#108186)
There are two scenarios:

* Scenario 1: The checkpoint was saved with pytorch < 1.6
* Scenario 2: The checkpoint was saved with pytorch >= 1.6

Repro Scenario 1:

```python
from torch._subclasses import fake_tensor
import transformers

fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
    fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")
```

Error:

```bash
Some weights of the model checkpoint at sshleifer/tiny-gpt2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:463 in           │
│ load_state_dict                                                                                  │
│                                                                                                  │
│    460 │   │   │   )                                                                             │
│    461 │   │   return safe_load_file(checkpoint_file)                                            │
│    462 │   try:                                                                                  │
│ ❱  463 │   │   return torch.load(checkpoint_file, map_location="cpu")                            │
│    464 │   except Exception as e:                                                                │
│    465 │   │   try:                                                                              │
│    466 │   │   │   with open(checkpoint_file) as f:                                              │
│                                                                                                  │
│ /opt/pytorch/torch/serialization.py:1030 in load                                                 │
│                                                                                                  │
│   1027 │   │   │   │   return _legacy_load(opened_file, map_location, _weights_only_unpickler,   │
│   1028 │   │   │   except RuntimeError as e:                                                     │
│   1029 │   │   │   │   raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None           │
│ ❱ 1030 │   │   return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args  │
│   1031                                                                                           │
│   1032                                                                                           │
│   1033 # Register pickling support for layout instances such as                                  │
│                                                                                                  │
│ /opt/pytorch/torch/serialization.py:1258 in _legacy_load                                         │
│                                                                                                  │
│   1255 │   _sys_info = pickle_module.load(f, **pickle_load_args)                                 │
│   1256 │   unpickler = UnpicklerWrapper(f, **pickle_load_args)                                   │
│   1257 │   unpickler.persistent_load = persistent_load                                           │
│ ❱ 1258 │   result = unpickler.load()                                                             │
│   1259 │                                                                                         │
│   1260 │   deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)                 │
│   1261                                                                                           │
│                                                                                                  │
│ /opt/pytorch/torch/_utils.py:201 in _rebuild_tensor_v2                                           │
│                                                                                                  │
│   198 def _rebuild_tensor_v2(                                                                    │
│   199 │   storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None    │
│   200 ):                                                                                         │
│ ❱ 201 │   tensor = _rebuild_tensor(storage, storage_offset, size, stride)                        │
│   202 │   tensor.requires_grad = requires_grad                                                   │
│   203 │   if metadata:                                                                           │
│   204 │   │   set_tensor_metadata(tensor, metadata)                                              │
│                                                                                                  │
│ /opt/pytorch/torch/_utils.py:180 in _rebuild_tensor                                              │
│                                                                                                  │
│   177 def _rebuild_tensor(storage, storage_offset, size, stride):                                │
│   178 │   # first construct a tensor with the correct dtype/device                               │
│   179 │   t = torch.tensor([], dtype=storage.dtype, device=storage._untyped_storage.device)      │
│ ❱ 180 │   return t.set_(storage._untyped_storage, storage_offset, size, stride)                  │
│   181                                                                                            │
│   182                                                                                            │
│   183 def get_tensor_metadata(tensor):                                                           │
│                                                                                                  │
│ /opt/pytorch/torch/utils/_stats.py:20 in wrapper                                                 │
│                                                                                                  │
│   17 │   │   if fn.__qualname__ not in simple_call_counter:                                      │
│   18 │   │   │   simple_call_counter[fn.__qualname__] = 0                                        │
│   19 │   │   simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1     │
│ ❱ 20 │   │   return fn(*args, **kwargs)                                                          │
│   21 │   return wrapper                                                                          │
│   22                                                                                             │
│                                                                                                  │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1160 in __torch_dispatch__                         │
│                                                                                                  │
│   1157 │   def __torch_dispatch__(self, func, types, args=(), kwargs=None):                      │
│   1158 │   │   assert self not in _get_current_dispatch_mode_stack(), func                       │
│   1159 │   │   try:                                                                              │
│ ❱ 1160 │   │   │   return self.dispatch(func, types, args, kwargs)                               │
│   1161 │   │   except TypeError:                                                                 │
│   1162 │   │   │   log.exception("fake tensor raised TypeError")                                 │
│   1163 │   │   │   raise                                                                         │
│                                                                                                  │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1318 in dispatch                                   │
│                                                                                                  │
│   1315 │   │                                                                                     │
│   1316 │   │   # we are falling through to running non constant tensors, any input constant tha  │
│   1317 │   │   # is written to must be invalidated                                               │
│ ❱ 1318 │   │   self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)   │
│   1319 │   │                                                                                     │
│   1320 │   │   # Try for fastpath                                                                │
│   1321 │   │   if has_symbolic_sizes:                                                            │
│                                                                                                  │
│ /opt/pytorch/torch/_subclasses/fake_tensor.py:1557 in invalidate_written_to_constants            │
│                                                                                                  │
│   1554 │   │   any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)         │
│   1555 │   │   if any_constant and get_schema_info(func).is_mutable():                           │
│   1556 │   │   │   schema_info = get_schema_info(func)                                           │
│ ❱ 1557 │   │   │   _, new_kwargs = normalize_function(                                           │
│   1558 │   │   │   │   func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True         │
│   1559 │   │   │   )                                                                             │
│   1560 │   │   │   for k, v in new_kwargs.items():                                               │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:297 in normalize_function                              │
│                                                                                                  │
│   294 │   │   new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,    │
│   295 │   else:                                                                                  │
│   296 │   │   assert callable(target)                                                            │
│ ❱ 297 │   │   torch_op_schemas = get_signature_for_torch_op(target)                              │
│   298 │   │   matched_schemas = []                                                               │
│   299 │   │   if torch_op_schemas:                                                               │
│   300 │   │   │   # Iterate through all of the schema until we find one that matches             │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in get_signature_for_torch_op                      │
│                                                                                                  │
│   164 │   │   │   return (None, None) if return_schemas else None                                │
│   165 │   │   schemas = torch._C._jit_get_schemas_for_operator(aten_fn)                          │
│   166 │                                                                                          │
│ ❱ 167 │   signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]          │
│   168 │   return (signatures, schemas) if return_schemas else signatures                         │
│   169                                                                                            │
│   170 @compatibility(is_backward_compatible=False)                                               │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:167 in <listcomp>                                      │
│                                                                                                  │
│   164 │   │   │   return (None, None) if return_schemas else None                                │
│   165 │   │   schemas = torch._C._jit_get_schemas_for_operator(aten_fn)                          │
│   166 │                                                                                          │
│ ❱ 167 │   signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]          │
│   168 │   return (signatures, schemas) if return_schemas else signatures                         │
│   169                                                                                            │
│   170 @compatibility(is_backward_compatible=False)                                               │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:70 in _torchscript_schema_to_signature                 │
│                                                                                                  │
│    67 │   from inspect import Parameter                                                          │
│    68 │   parameters : List[Parameter] = []                                                      │
│    69 │   for arg in ts_schema.arguments:                                                        │
│ ❱  70 │   │   arg_type = _torchscript_type_to_python_type(arg.type)                              │
│    71 │   │   default = arg.default_value if arg.has_default_value() else Parameter.empty        │
│    72 │   │   # TODO: Figure out if this is safe. It seems like when generating the type signa   │
│    73 │   │   # PythonArgParser, we emit signatures with `input` instead of `self` as the firs   │
│                                                                                                  │
│ /opt/pytorch/torch/fx/operator_schemas.py:64 in _torchscript_type_to_python_type                 │
│                                                                                                  │
│    61 │   eval'ing the annotation_str. _type_eval_globals sets up expressions                    │
│    62 │   like "List" and "Future" to map to actual types (typing.List and jit.Future)           │
│    63 │   """                                                                                    │
│ ❱  64 │   return eval(ts_type.annotation_str, _type_eval_globals)                                │
│    65                                                                                            │
│    66 def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Sig   │
│    67 │   from inspect import Parameter                                                          │
│ <string>:1 in <module>                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NameError: name 'Storage' is not defined

During handling of the above exception, another exception occurred:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:467 in           │
│ load_state_dict                                                                                  │
│                                                                                                  │
│    464 │   except Exception as e:                                                                │
│    465 │   │   try:                                                                              │
│    466 │   │   │   with open(checkpoint_file) as f:                                              │
│ ❱  467 │   │   │   │   if f.read(7) == "version":                                                │
│    468 │   │   │   │   │   raise OSError(                                                        │
│    469 │   │   │   │   │   │   "You seem to have cloned a repository without having git-lfs ins  │
│    470 │   │   │   │   │   │   "git-lfs and run `git lfs install` followed by `git lfs pull` in  │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/codecs.py:322 in decode                                       │
│                                                                                                  │
│    319 │   def decode(self, input, final=False):                                                 │
│    320 │   │   # decode input (taking the buffer into account)                                   │
│    321 │   │   data = self.buffer + input                                                        │
│ ❱  322 │   │   (result, consumed) = self._buffer_decode(data, self.errors, final)                │
│    323 │   │   # keep undecoded input until the next call                                        │
│    324 │   │   self.buffer = data[consumed:]                                                     │
│    325 │   │   return result                                                                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte

During handling of the above exception, another exception occurred:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /opt/pytorch/bug_repro.py:16 in <module>                                                         │
│                                                                                                  │
│   13 fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")                  │
│   14 assert fake_model is not None                                                               │
│   15 with fake_mode:                                                                             │
│ ❱ 16 │   fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2")  # raises    │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py:484 in │
│ from_pretrained                                                                                  │
│                                                                                                  │
│   481 │   │   │   )                                                                              │
│   482 │   │   elif type(config) in cls._model_mapping.keys():                                    │
│   483 │   │   │   model_class = _get_model_class(config, cls._model_mapping)                     │
│ ❱ 484 │   │   │   return model_class.from_pretrained(                                            │
│   485 │   │   │   │   pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs,   │
│   486 │   │   │   )                                                                              │
│   487 │   │   raise ValueError(                                                                  │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:2604 in          │
│ from_pretrained                                                                                  │
│                                                                                                  │
│   2601 │   │   if from_pt:                                                                       │
│   2602 │   │   │   if not is_sharded and state_dict is None:                                     │
│   2603 │   │   │   │   # Time to load the checkpoint                                             │
│ ❱ 2604 │   │   │   │   state_dict = load_state_dict(resolved_archive_file)                       │
│   2605 │   │   │                                                                                 │
│   2606 │   │   │   # set dtype to instantiate the model under:                                   │
│   2607 │   │   │   # 1. If torch_dtype is not None, we use that dtype                            │
│                                                                                                  │
│ /opt/conda/envs/ptca/lib/python3.8/site-packages/transformers/modeling_utils.py:479 in           │
│ load_state_dict                                                                                  │
│                                                                                                  │
│    476 │   │   │   │   │   │   "model. Make sure you have saved the model properly."             │
│    477 │   │   │   │   │   ) from e                                                              │
│    478 │   │   except (UnicodeDecodeError, ValueError):                                          │
│ ❱  479 │   │   │   raise OSError(                                                                │
│    480 │   │   │   │   f"Unable to load weights from pytorch checkpoint file for '{checkpoint_f  │
│    481 │   │   │   │   f"at '{checkpoint_file}'. "                                               │
│    482 │   │   │   │   "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please s  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OSError: Unable to load weights from pytorch checkpoint file for '/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin' at
'/root/.cache/huggingface/hub/models--sshleifer--tiny-gpt2/snapshots/5f91d94bd9cd7190a9f3216ff93cd1dd95f2c7be/pytorch_model.bin'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set
from_tf=True.
```
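
Both repro scenarios bottom out in the same root cause visible in the traceback above: under `FakeTensorMode`, the `Tensor.set_(storage, ...)` call made by `_rebuild_tensor` is intercepted by `__torch_dispatch__`, and schema normalization (`normalize_function` → `_torchscript_type_to_python_type`) tries to `eval` the TorchScript annotation string `Storage`, which is not defined in `_type_eval_globals`. A minimal sketch of just that failing step (the trimmed-down globals dict below is hypothetical, only there to make the `eval` reproducible):

```python
# Hypothetical, trimmed-down stand-in for torch.fx.operator_schemas._type_eval_globals:
_type_eval_globals = {"Tensor": object, "List": list}

# _torchscript_type_to_python_type evaluates the schema's annotation string.
# For aten::set_, one argument is typed `Storage`, which the globals lack:
eval("Storage", _type_eval_globals)  # NameError: name 'Storage' is not defined
```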

Repro scenario 2:

```python
import tempfile
import torch
from torch._subclasses import fake_tensor

class TheModelClass(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 10)

    def forward(self, x):
        return self.fc1(x)

with tempfile.NamedTemporaryFile() as state_dict_file:
    # Create state_dict to be loaded later
    model = TheModelClass()
    torch.save(model.state_dict(), state_dict_file.name)

    fake_mode = fake_tensor.FakeTensorMode()
    with fake_mode:
        # This is where the bug is triggered
        state_dict = torch.load(state_dict_file.name)
```

Error:

```bash
Traceback (most recent call last):
  File "issue_gh_torch_105077.py", line 22, in <module>
    state_dict = torch.load(state_dict_file.name)
  File "/opt/pytorch/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
  File "/opt/pytorch/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
  File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
  File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
    self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
  File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
    _, new_kwargs = normalize_function(
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
    torch_op_schemas = get_signature_for_torch_op(target)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
    arg_type = _torchscript_type_to_python_type(arg.type)
  File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
    return eval(ts_type.annotation_str, _type_eval_globals)
  File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```

This PR adds the ability to create fake tensors during `torch.load` when a fake mode is active, by changing the storage's device to `meta`.
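
As a hedged illustration of what the change enables (the checkpoint path below is hypothetical): with the fix, the scenario-2 repro should complete, and the loaded values come back as fake tensors whose storages live on the `meta` device, so no real memory is allocated:

```python
import torch
from torch._subclasses import fake_tensor

fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
    # With this PR, torch.load notices the active fake mode and rebuilds each
    # tensor on top of a meta-device storage instead of allocating real data.
    state_dict = torch.load("checkpoint.pt")  # hypothetical checkpoint path

# The loaded values are FakeTensors: shapes and dtypes are preserved,
# but no checkpoint data is actually materialized.
for name, t in state_dict.items():
    print(name, type(t).__name__, tuple(t.shape), t.dtype)
```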

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang, https://github.com/atalman
2024-02-16 23:42:50 +00:00
Sam Larsen
3e5e8590f4 Account for inference mode in FakeTensor cache (#119963)
Summary: An fbcode test exposed a shortcoming where we could serve a FakeTensor from the cache with the wrong `inference_mode` state. Take the current mode into account in the cache key so we only serve entries that were created under the same mode we are currently in.
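
A minimal sketch of the idea (the helper name and key layout are hypothetical; the real cache lives in `torch/_subclasses/fake_tensor.py`): fold the ambient inference-mode state into the key, so a cache hit is only possible under the same mode the entry was recorded in:

```python
import torch

def cache_key(op_name, args_key):
    # Hypothetical helper: include the current inference_mode state in the
    # key so entries recorded under inference mode never leak outside it.
    return (op_name, args_key, torch.is_inference_mode_enabled())

with torch.inference_mode():
    k_inference = cache_key("aten.add.Tensor", ("f32[3]", "f32[3]"))
k_normal = cache_key("aten.add.Tensor", ("f32[3]", "f32[3]"))
assert k_inference != k_normal  # same op/args, different modes -> distinct entries
```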

Test Plan: New unit test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119963
Approved by: https://github.com/eellison
2024-02-16 02:53:33 +00:00