More details further down, but first a high-level description of how we functionalize storage resizing.
Today, dynamo converts `param.untyped_storage().resize_(x)` calls that it sees from fsdp into a custom op, `ops.inductor.resize_storage_bytes_(x)`
So given this setup, there are 3 main cases that I think we want to handle:
(1) graph input starts with a real storage size, gets resized down to zero in the graph
(2) graph input starts with 0 storage size, gets resized up in the graph
(3) graph input starts with 0 storage size, gets resized up and used in some compute, then resized back down to 0
For case (1) we need to emit a `resize_storage_bytes_` at the end of the graph, similar to how we emit `copy_()` for data mutations.
For case (2), we need to emit a `resize_storage_bytes_` in the graph, and we **also** need to emit a `copy_()` (the input had its storage resized up and filled in with data, which we need to reflect as an input mutation)
For case (3), the net effect is that the input had no data on entry and exit of the function, so we don't need to emit any mutable ops at the end of the graph.
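To make the three cases concrete, here is a hedged sketch (hypothetical user code; the comments describe what AOTAutograd would emit at the end of the traced graph):
```
def case_1(param):
    # starts with real storage, resized down to zero
    param.untyped_storage().resize_(0)
    # epilogue: resize_storage_bytes_(param, 0)

def case_2(param, all_gather_buf, n):
    # starts with zero storage, resized up and filled in with data
    param.untyped_storage().resize_(n)
    param.copy_(all_gather_buf)
    # epilogue: resize_storage_bytes_(param, n), then a copy_() for the data mutation

def case_3(param, all_gather_buf, n):
    # resized up, used in compute, resized back down to zero: net no-op
    param.untyped_storage().resize_(n)
    param.copy_(all_gather_buf)
    out = param @ param
    param.untyped_storage().resize_(0)
    # epilogue: nothing needed for param's storage
    return out
```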
The main thing to call out is that we need to write a functionalization rule for `resize_storage_bytes_` (`FunctionalTensorWrapper::storage_resize_()`), and this rule actually does very little. We would like to **not** emit any new ops in the graph (like, say, a functional resize op). Instead, we should expect / rely on the fact that any resize up will be immediately followed by a `copy_()`/`foreach_copy_`/`out=` op that will fill in the data of the tensor. So `FunctionalTensor` can temporarily live in a state where its data is invalid, until the `x.copy_(y)` "updates" its data with the new tensor.
So effectively, all that this rule does is:
(1) it stores metadata on the storage, indicating that the tensor was resized, as well as the updated storage size. We need this info in AOTAutograd, so it knows whether to emit a mutable resize_() op in the graph epilogue
(2) There is also a corner case: if we are resizing down to zero, but our tensor **previously** had a zero-size storage, then we update `value_` to point to the original value of the tensor. The reason this seems safe is that if we have a zero-storage-sized tensor `x`, and we resize it up, use it in some compute, resize it back down to zero, and use it somewhere, we would want the functional version of this code to use the original `x` after the second resize. For FSDP, this is important because we end up saving parameters (graph inputs) for backward, and we want to make sure that the thing we save (and the output of the forward graph) is the original, zero-storage-sized parameter, and not the "version 2" of the parameter after the first resize_()
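A minimal sketch of the trace-time decision this metadata enables (hypothetical helper, not the actual AOTAutograd code): compare the storage size at the start and end of the trace to decide what to emit in the graph epilogue.
```
def epilogue_ops_for_input(original_nbytes, final_nbytes):
    # returns a description of the mutable ops AOTAutograd would emit
    if original_nbytes == final_nbytes:
        # case (3): resize up then back down cancels out -> no storage op needed
        return []
    if final_nbytes == 0:
        # case (1): input ended up resized down to zero
        return ["resize_storage_bytes_(input, 0)"]
    # case (2): input was resized up and filled in -> reflect the new storage
    # size and the data mutation
    return [f"resize_storage_bytes_(input, {final_nbytes})", "copy_(input, updated_value)"]
```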
I think a good order to look at changes in this PR would be:
(1) `test_aotdispatch.py` shows the 3 main cases I focused on as well as the expected functionalized graphs
(2) In `FunctionalStorageImpl.h/cpp`, I had to add a notion of "original base", and "original/curr_size". The first is so I can re-use the zero-size tensor after multiple resizes, and the second is so I can tell in AOTAutograd whether any resizes canceled each other out into a no-op
(3) `FunctionalTensorWrapper.h/cpp` has the new resize functionalization rule + some extra utils
(4) `_functorch/_autograd`: the main changes in this folder were around adding the logic at trace-time to detect when we need to put a resize_() in the graph. I also have some assertions to check that any inputs that experience storage resizing will **always be in the graph** and not the opaque epilogue, and I also limited the resize_() mutation case so that you can only ever start with zero storage, or end with zero storage (you can't do e.g. `torch.ones(2).storage().resize_(3)`), and banned it on tensor subclasses
(5) `fake_tensor.py`/`meta_utils.py`: we now need to be able to fakeify tensors with zero storage, so I added a quick version of it in meta_utils.py. This also has ramifications for fake tensor caching that I need to fix (include the storage size in the cache key, maybe?)
------------------
This PR subsumes https://github.com/pytorch/pytorch/pull/120971.
This PR is enough to **almost** get a simple ppFSDP forward pass tracing with a functionalized resize_() properly. It also attempts to do the updated version from @jansel, where we don't have any notion of `resize_()` in the graph at all, post functionalization. It would probably be good to test it with @yf225 's FSDP changes, and see how many of the FX passes it allows us to remove. I think that in theory, it should allow us to remove all FX passes that affect the forward graph / partitioner, **except** the one that forces views to be recomputed in the backward (more details below).
There are a few things worth calling out:
(1) failed attempt at functionalizing `aten.copy_()`. I originally wanted to get a version that takes these operations:
```
param.storage().resize_(all_gather_size)
param.copy_(all_gather_buffer)
out = aten.matmul(param, param)
```
and functionalizes them into:
```
out = aten.matmul(all_gather_buffer, all_gather_buffer)
```
This would involve getting functionalization to turn `x.copy_(y)` into a giant no-op that just returns `y`. Unfortunately, we can't actually do this in a reasonable way within functionalization (instead, there's a functional `aten.copy` in the graph - see the test case graph expecttest for details). Why? In order for that transformation to be safe, `x` and `y` need to have the same metadata. However, it's possible for `x` and `y` to be subclasses of different types. This is not something we can easily tell from within functionalization, and would be a layering violation. So for now I'm leaving it to downstream code to optimize away the `aten.copy` (this is already the case today, so I think inductor can handle this)
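As a concrete illustration of why the no-op rewrite is unsafe when metadata differs (a plain eager example, not code from this PR):
```
import torch

x = torch.zeros(4, dtype=torch.float64)
y = torch.arange(4, dtype=torch.float32)

x.copy_(y)        # legal: copy_ casts y's data into x's dtype
print(x.dtype)    # torch.float64 -- downstream code still sees float64
# If functionalization rewrote `x.copy_(y)` to just "use y", downstream code
# would suddenly see float32 (or a different subclass entirely), which is why
# the rewrite is only safe when x and y have identical metadata.
```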
(2) The forward doesn't **actually** run successfully in this PR (see the `assertRaisesRegex` in the test). Why?
The final forward graph looks like this:
```
def forward(self, primals_1, primals_2):
    _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_2 = None
    getitem = _foreach_copy[0]; _foreach_copy = None
    mm = torch.ops.aten.mm.default(getitem, getitem); getitem = None
    t_1 = torch.ops.aten.t.default(primals_1); primals_1 = None
    return [mm, t_1]
```
Where `primals_1` starts out as a secretly-zero-storage-size parameter, and gets resized up and back down within the forward (these are functionalized away).
Importantly, the matmul happens on the result of the `foreach_copy`, **but** the activation that we save for backward (`t_1`) is the result of transposing the **original parameter** (the zero-storage-size param). This is exactly the optimization in fsdp that allows us to have good peak memory usage.
The problem is that the min-cut partitioner decides to save `t_1` for backward. Running this code in eager breaks, because the kernel for `aten.permute(x)` is not happy when `x` has secretly-zero-sized-storage.
The real problem here is that in eager mode the `permute` kernel runs during the backward, after backward hooks have properly resized the saved activation. Here, we are running the transpose in the forward.
One option would be to turn off the checks in our view kernels and allow them to work on zero-storage-sized tensors, which feels pretty bad. Another option is to tweak the partitioner (or use one of Will's FX passes) to force the partitioner to not save views for backward, and allow the views to be recomputed in the backward. This seems kind of silly, but is also probably harmless.
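For illustration, a hedged sketch (not the actual pass used by FSDP) of how an FX pass could detect which forward outputs are views, i.e. the nodes we would want the partitioner to recompute in the backward rather than save:
```
import torch
from torch.fx import GraphModule, Node

def find_view_outputs(gm: GraphModule):
    # Find forward-graph outputs produced by aliasing (view) ops.
    output_node = next(n for n in gm.graph.nodes if n.op == "output")
    views = []
    for arg in output_node.args[0]:
        if not isinstance(arg, Node) or arg.op != "call_function":
            continue
        target = arg.target
        if isinstance(target, torch._ops.OpOverload) and any(
            r.alias_info is not None for r in target._schema.returns
        ):
            views.append(arg)   # e.g. the t_1 node above
    return views
```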
(3) The backward is still broken. To be fair, this issue is pretty separable from "functionalizing storage resize calls", and can be fixed later (either by a real fix to our tracing infra, or via another hacky FX pass). More description of this problem is described at issue (8) of my PR description in https://github.com/pytorch/pytorch/pull/120971
(4) I only added support for "full graph" resizing: basically, the limited case where a param starts with zero storage size, and gets resized up and back down. I think we can add support for the graph break case, but I think we can keep that add-on separate from this PR unless we need it immediately. I also added asserts so we should fail loudly when we hit this case
(5) I have a change to FakeTensor creation when inputs have zero storage size that is probably OK. But I also removed FakeTensor caching on view ops, which I probably need to fix before I can land this PR
(6) I added a notion of "original_base" to `FunctionalStorageImpl`. More details are in the comments, but my rationale for this was that we basically need it to ensure that autograd saves the **original**, zero-storage-sized param for backward, after resizing up and back down
(7) I had to update our eager kernels for `aten.copy` and `aten._foreach_copy`, to handle the case where the `self` argument has secretly-zero-storage. Inductor can probably generate correct code for this case, but we need these ops to work properly in this situation for the `aot_eager` backend to do the right thing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122434
Approved by: https://github.com/jansel
When dispatching a fake tensor op we cache the result with `(op, args)` as the key. There are some args (such as one with a dynamic output shape) where the output can't be cached. Instead of validating the args every time we compute the cache key, only validate the args when we first see a new cache key.
18.3% FakeTensor perf win on the microbenchmark (21.7% cumulative)
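For reference, a minimal sketch of the caching pattern described above (hypothetical names like `can_cache`/`run`, not the actual FakeTensor cache): arguments are validated only when a key is first inserted, never on a hit.
```
_cache = {}
_UNCACHEABLE = object()

def cached_dispatch(op, args):
    key = (op, args)
    entry = _cache.get(key)
    if entry is None:
        # first time we see this key: pay the validation cost once
        if not can_cache(op, args):   # e.g. dynamic output shape
            _cache[key] = _UNCACHEABLE
            return run(op, args)
        entry = _cache[key] = run(op, args)
        return entry
    if entry is _UNCACHEABLE:
        return run(op, args)
    return entry
```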
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124223
Approved by: https://github.com/oulgen, https://github.com/masnesral
ghstack dependencies: #122911
A common complaint when working with data-dependent code in PyTorch is that it's hard to tell how far you are from the finish line: every time a GuardOnDataDependentSymNode error is hit, you have to somehow fix or workaround it to see the next one.
This PR adds a new mode `torch._functorch.config.fake_tensor_propagate_real_tensors` which modifies fake tensors to also propagate real tensors. This means that when we try to guard on a data-dependent SymNode, we can actually produce a real result. We also produce a warning which you should consult to figure out what the crux points are.
I ran this on vision_maskrcnn. In the baseline (without this mode), the model has 27 graph breaks, resulting in 40 graphs. With this mode on, the model has only 11 graph breaks, resulting in 15 graphs (the remaining graph breaks are due to missing functionality for item() on float tensors and some other missing Dynamo features.) You get a list of things that would have errored like this:
```
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u0), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u1) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u1), 1)) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Ne(Max(1, u1), 1)) -> True
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Max(1, u0) < 2) -> False
WARNING:torch.fx.experimental.symbolic_shapes:propagate_real_tensors evaluate_expr(Eq(Max(1, u0), 1)) -> False
```
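For reference, a hedged usage sketch of the mode (the propagate flag name comes from this description; `capture_scalar_outputs` is standard Dynamo config; the toy function is made up):
```
import torch
import torch._dynamo.config as dynamo_config
import torch._functorch.config as functorch_config

dynamo_config.capture_scalar_outputs = True                  # trace .item() as an unbacked SymInt
functorch_config.fake_tensor_propagate_real_tensors = True   # also propagate real tensors

@torch.compile()
def f(x):
    n = x.sum().item()    # data-dependent value
    if n > 0:             # guard on a data-dependent SymNode
        return x * 2
    return x - 1

f(torch.ones(4))
# Instead of a GuardOnDataDependentSymNode error, the guard is evaluated
# against the real tensor and logged as a propagate_real_tensors warning.
```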
Potential later follow ups:
* Improve the warning messages (in particular, should provide user frames)
* GC real tensors when they are no longer needed by tracing. Right now, this will use A LOT of memory, as much as if your GC was broken and every intermediate tensor was kept live
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125115
Approved by: https://github.com/IvanKobzarev
To fix data-dependent errors we want to recommend that people use `torch._check*` APIs. The `constrain_as*` APIs should be fully subsumed by them, and in the future we should kill them entirely.
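For illustration, a short example (not from this PR) of the recommended `torch._check*` style, telling the compiler facts about a data-dependent value:
```
import torch

def f(scores):
    n = scores.sum().to(torch.int64).item()   # data-dependent (unbacked) value under tracing
    torch._check_is_size(n)                   # n is a valid size (>= 0)
    torch._check(n <= scores.numel())         # extra fact the tracer can use
    return torch.zeros(n)
```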
Differential Revision: D56774333
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125253
Approved by: https://github.com/ezyang
fake_tensor.py had mypy errors ignored. That seems less than desirable.
Also added SafePyObjectT<T> which is a tagged wrapper around a SafePyObject but provides static type checking (with no other guarantees).
Used `SafePyObjectT<TorchDispatchModeKey>` on some of the TorchDispatchModeTLS API to ensure that we don't accidentally inject a different type than expected into the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124428
Approved by: https://github.com/malfet
This PR has a lot of "draw the rest of the fucking owl" energy. Here's how to break it down.
1. **torch/_inductor/graph.py** - We start by tightening unbacked symbol invariants. Specifically, as we lower FX nodes, we check whether or not every unbacked_binding recorded on the FX node meta, actually ends up getting bound (according to get_unbacked_symbol_defs) in all the buffers generated by the lowering. Hopefully this invariant is self evident. This leads to a lot of failures.
2. **torch/_inductor/ir.py** - Problem 1: There is softness in how Inductor computes defs of unbacked symbols in IR node. Previously, we tried to infer it by looking at the output sizes/strides/etc and see if new unbacked symbols popped up that we hadn't seen in the inputs. I don't know exactly what was buggy about the old code, but sometimes we would fail to notice an unbacked symbol had been bound, or rebind an unbacked symbol multiple times. Fortunately, thanks to the earlier PRs in our stack, we now have a nice list of unbacked symbol bindings from FX, so we now just store it directly on ExternKernel and use it directly to report defs. This has to be done twice: once for FallbackKernel (e.g., nonzero) and once for DynamicScalar (e.g., item) (see also **torch/_inductor/lowering.py**, **torch/_inductor/codegen/wrapper.py** and **torch/_inductor/codegen/cpp_wrapper_cpu.py** for the lowering and codegen changes for item)
* **process_kernel** - Sidequest! It turns out that Inductor lowering can reallocate unbacked symbols. This happens specifically when we repropagate fake tensors through the operator in `process_kernel`. This repropagation process is necessary because Inductor may have changed the strides of input tensors, and it must now recompute the strides so that it can continue to appropriately plan the rest of the lowering process. This is fine: we just make sure we do the rebind unbacked + compute_unbacked_bindings dance we've been doing previously in the PR stack. But instead of putting unbacked_bindings on a new FX node, they go straight into our unbacked_bindings on the Inductor IR node.
* **codegen_unbacked_symbol_defs** - Sidequest! FallbackKernel lowering is done in two steps. First, you emit the FallbackKernel buffer. Then, you emit MultiOutput buffers which actually give access to the individual outputs of FallbackKernel, which may have been multi-output. There is a design decision here: does the FallbackKernel bind the unbacked symbols, or the MultiOutput buffer? Historically, we put the binding on MultiOutput buffer, because it's more convenient: the FallbackKernel buffer is fake, in fact, it doesn't even get a name in C++ codegen. But it's kind of inconsistent with the keypath model that we've been tracking unbacked bindings with: if you have a multi-output node, you'd expect a keypath like `[0].size()[0]` representing the first output's first dimension size. That suggests that it's the FallbackKernel that should define the things. So that was my first implementation. Unfortunately, the C++ codegen is too cursed and I could not understand how to make it work in that case. So now we just unsoundly assume you cannot have multi-output data dependent output, and do the codegen in MultiOutput. There are some comments explaining exactly what we are improperly assuming.
3. **_rename_unbacked_to** in **torch/fx/experimental/symbolic_shapes.py** - Previously, when we renamed unbacked symbols, we clobbered any facts we previously knew about them. So for example, if we had a replacement `u0 -> s0` but then we renamed u0 to u1, we would now setup the replacement `u0 -> u1`, clobbering the old replacement. This apparently didn't matter in earlier PRs in the stack, but with Inductor now on the ball, there were some tests that indicated this was a problem. The solution is easy: if u0 had a preexisting replacement, reapply it to u1. However...
* **torch/_functorch/_aot_autograd/collect_metadata_analysis.py** - When we run forward analysis, this triggers fake tensor repropagation and fresh allocations. Previously, we just cleared out the pending symbols when finished the analysis. But with the change above, this would also migrate replacements to the new symbols... which are now dead. So now we explicitly suppress generation of these symbols with `ignore_fresh_unbacked_symbols` so that no rebinding happens at all.
* **torch/_dynamo/eval_frame.py** - same deal; I just searched for all sites we called clear() on pending
4. The last step is fixing the long tail of extra problems that show up, now that unbacked_bindings are load bearing into Inductor
* **torch/_dynamo/eval_frame.py** - Some of the exports are making copies of nodes without repropagating fake tensors, so in this case, it is important to also copy the `unbacked_bindings` (apparently this didn't matter before without the Inductor changes)
* **torch/_export/pass_base.py** - I discover that this is doing fake tensor repropagation via a test suite failure. Do the same playbook as AOTAutograd: PropagateUnbackedSymInts too! Actually, they also have implemented their own tracer as well, so do the same playbook as proxy_tensor: record unbacked_bindings on the newly traced nodes. UGH code duplication.
* **torch/_subclasses/fake_tensor.py**, **torch/_subclasses/fake_impls.py** (with call site updates at **torch/_functorch/_aot_autograd/traced_function_transforms.py** and **torch/fx/passes/fake_tensor_prop.py**) - What's this new epoch thing? I noticed that sometimes I would be retracing, call nonzero() on a fake tensor, and not allocate a new unbacked symbol. This is actually bad, because if I don't get a new unbacked symbol, I don't know there's a binding site, and `unbacked_bindings` is now missing a binding. The reason for this is memoization: if I reuse the exact same fake tensor on my retrace, it will already have an unbacked symint memoized on it and we will short circuit allocation. Well, that's no good. So I associate the memos with a fake tensor epoch, and every time you start a new fake tensor propagation from scratch, you bump the epoch so that I clear all the memos. (A minimal sketch of this epoch bookkeeping appears after this list.)
* **torch/_inductor/scheduler.py** - I notice in unit tests that V.current_node is not always set when we call process_kernel. So I save it into the IR node and restore it when we are running `get_estimated_runtime`.
* **torch/fx/experimental/symbolic_shapes.py** - A few things
* **rebind_unbacked** (re **_tensor_version**). Ordinarily, when you have an unbacked SymInt, you persistently have it all the way to the end of the program. `_tensor_version` violates this: this generates an unbacked SymInt (for reasons I don't quite understand?) and then gets rid of it later. This triggered an assert violation. I think this op is kind of misusing unbacked SymInt, but I didn't know how to refactor it, so it gets a special case.
* **rebind_unbacked** (re **Simplify SymBool binding**). Ugh, SymBool, what a pain in the butt. I have an assert that you can only rebind unbacked symbol to another unbacked symbol. This assert fails when a boolean is involved, because the result of running keypath on the result is not `u1`, it's `sympy.Piecewise(... sympy.Eq(u1, 1) ...)`. This is actually just `u1`, but Sympy doesn't know it because it doesn't know that `u1` value range is `[0, 1]`. So we manually implement the simplification needed to get the assert to pass.
* **compute_unbacked_bindings** (re **This is pretty fragile**). There is a really funny disaster involving memoization and Inductor process kernel. Ordinarily when I retrace, if there was a memo hit in the old trace, there will be a memo hit in the new trace. However, Inductor process kernel breaks this, because it recreates fake tensor inputs to the operator call from scratch (since they might have different strides), and obviously these tensor inputs don't have the memo from the old one. I tried a little bit to manually transplant the memo to the new fake tensor but it seemed hopeless, so I just let the fresh symbol ride, allocating a new unbacked symbol. However, in one of our tests, we rely on knowing that the first nonzero call is equal to the second (memoized) nonzero call. The equality test looked pretty easy to discharge, so I just went ahead and added a deferred runtime assert to this effect and it worked.
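A minimal sketch of the epoch bookkeeping mentioned in the fake_tensor.py bullet above (hypothetical class, not the real implementation): memos are only trusted if they were recorded in the current epoch, and starting a fresh propagation bumps the epoch.
```
class EpochMemo:
    def __init__(self):
        self.epoch = 0
        self._memo = {}          # key -> (epoch, value)

    def start_new_propagation(self):
        self.epoch += 1          # previously recorded memos become stale

    def get(self, key):
        hit = self._memo.get(key)
        if hit is None or hit[0] != self.epoch:
            return None          # stale memo: force a fresh allocation
        return hit[1]

    def set(self, key, value):
        self._memo[key] = (self.epoch, value)
```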
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124394
Approved by: https://github.com/jansel
ghstack dependencies: #124310, #124314, #124316
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime errors, value errors, type errors, or some other specific errors. There are hundreds of instances of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get committers to rethink what exception type they should raise when they submit a PR.
I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed over time and our exception typing can be improved.
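An illustrative before/after for the rule (example code, not from this PR): raising a bare `Exception` gets flagged; raise a specific error type instead, or suppress intentionally with a noqa comment while the backlog is worked down.
```
def load_config(path):
    if not path.endswith(".yaml"):
        # raise Exception("bad config path")   # flagged by the new lint rule
        raise ValueError(f"expected a .yaml config, got {path!r}")
    ...
```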
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
Also partially fixes #122109
This PR:
- We add a C++ flag (only_lift_cpu_tensors) to toggle the
torch.tensor(1, device='cuda') ctor strategy.
When false (default), it does the current PyTorch behavior
of unconditionally constructing a concrete CUDA tensor then calling
lift_fresh on it. When true, we instead construct a concrete CPU
tensor, call lift_fresh, and then call Tensor.to(device) (under any ambient
modes).
- FakeTensorMode flips this flag depending on if CUDA is available or
not. We don't unconditionally set the flag to True because that is
likely BC-breaking.
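A plain-Python sketch of the two construction strategies described above (for illustration only; the real toggle is a C++ flag consulted inside the torch.tensor ctor, and lift_fresh is elided here):
```
import torch

def make_scalar_tensor(value, device, only_lift_cpu_tensors):
    if not only_lift_cpu_tensors:
        # old behavior: build the concrete tensor directly on `device`, which
        # requires e.g. a working CUDA runtime even under FakeTensorMode
        return torch.tensor(value, device=device)
    # new behavior: build on CPU, then move under whatever ambient modes are
    # active, so FakeTensorMode can intercept the .to() without touching CUDA
    return torch.tensor(value, device="cpu").to(device)
```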
Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124413
Approved by: https://github.com/eellison
If we throw an exception in the "wrong" place we can end up with the dispatcher in a weird state, which can cause all future dispatching to fail. Preserve and restore it as part of `preserve_global_state` so we know it's sane after that.
Also fake_tensor's in_kernel_invocation_manager() was leaving a bit set in the dispatcher (DispatchKey.Dense) which affected follow-on code. Fixed that to reset it afterward as well.
Repro:
before:
```
$ rm test/dynamo_skips/TestSparseCPU.test_to_dense_with_gradcheck_sparse_cpu_complex64
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_to_dense_with_gradcheck_sparse_cpu_complex64'
======== 1 passed, 6173 deselected in 5.21s =============
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_torch_inference_mode_ctx or test_to_dense_with_gradcheck_sparse_cpu_complex64'
========= 1 skipped, 6172 deselected, 1 error in 5.29s =========
```
(note that test_to_dense_with_gradcheck_sparse_cpu_complex64 passes on its own but failed when including the skipped test_export.py tests)
after:
```
$ rm test/dynamo_skips/TestSparseCPU.test_to_dense_with_gradcheck_sparse_cpu_complex64
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_to_dense_with_gradcheck_sparse_cpu_complex64'
===================== 1 passed, 6173 deselected in 5.42s =====================
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_torch_inference_mode_ctx or test_to_dense_with_gradcheck_sparse_cpu_complex64'
===================== 1 passed, 1 skipped, 6172 deselected in 7.30s ======================
```
(note that test_to_dense_with_gradcheck_sparse_cpu_complex64 passes in both runs)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122073
Approved by: https://github.com/zou3519
Today, we error out on FakeTensor.data_ptr under torch.compile. This PR
moves to error out on FakeTensor.data_ptr under eager mode to avoid
diverging behavior.
We do this by adding another bit onto FakeTensor that we'll remove after
the deprecation cycle.
Test Plan:
- tested locally
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123292
Approved by: https://github.com/eellison
ghstack dependencies: #123261, #123282, #123291
This PR only adds abstract class registration logic without touching existing tests so they still trace with real script objects. The added tests are only for registration APIs and test error messages.
Our design is that the abstract implementation should be in Python. This is much better in terms of usability. But this also has implications for custom op that takes script object as input, which is detailed later in this stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122622
Approved by: https://github.com/zou3519
ghstack dependencies: #122619, #122620, #122621
This PR:
- disallows FakeTensor.data_ptr when it is called inside PT2 or fx tracing.
- disallows FunctionalTensor.data_ptr (python FunctionalTensor is only used in
PT2)
The motivation behind this is that the leading cause of segfaults when
using custom ops with PT2 is calling .data_ptr on FunctionalTensor or
FakeTensor.
This change is BC-breaking. If your code broke as a result of this, it's
because there was a bug in it (these .data_ptr should never be
accessed!). You can either fix the bug (recommended) or get the previous
behavior back with:
```
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor
data_ptr = 0 if isinstance(tensor, (FakeTensor, FunctionalTensor)) else tensor.data_ptr()
```
Test Plan:
- existing tests
Differential Revision: [D55366199](https://our.internmc.facebook.com/intern/diff/D55366199)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122514
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/yifuwang, https://github.com/kurtamohler
It's annoying grepping for `__call__` call-sites, so they're all explicit now. I'd do this to MetaConverter too but that one is way more public, with a lot more sites.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122270
Approved by: https://github.com/eellison
ghstack dependencies: #122044
Fixes https://github.com/pytorch/pytorch/issues/121085
This PR is pretty involved, so pay attention to this description. At a high
level, the refactor is intended to be mechanical: anywhere in
MetaConverter where previously we took a Tensor as argument, we now take
a MetaTensorDesc, which contains all of the information that we would
have queried off of the Tensor, but placed into a separate data
structure which we can serialize or use to recreate a fake tensor in
a separate fake tensor mode in exact fidelity to the original.
However, this transformation is not always entirely mechanical. Here
is what you need to pay attention to:
- The memo table from real Tensor -> meta/fake Tensor is now broken
into two memo tables: real Tensor -> stable int id -> meta/fake
Tensor. The stable int id is needed so that when we do serialization,
we know when tensors/storages alias each other and can ensure we preserve
this aliasing upon deserialization.
The way I have implemented this changes the weak reference behavior.
Previously, when either the real Tensor OR the meta/fake Tensor went
dead, we would remove the entry from the memo table. Now, this only
removes entries from one of the two memo tables. This semantically
makes sense, because the user may have held on to the stable int id
out of band, and may expect a real Tensor to continue to be numbered
consistently / expect to be able to lookup a meta/fake tensor from
this id. If this is unacceptable, it may be possible to rejigger
the memo tables so that we have real Tensor -> stable int id
and real Tensor -> meta/fake Tensor, but TBH I find the new
implementation a lot simpler, and arranging the memo tables in this
way means that I have to muck around with the real tensor to save
to the memo table; in the current implementation, I never pass the
Tensor to the meta_tensor function AT ALL, which means it is impossible
to accidentally depend on it.
- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
to be careful not to poke fields when they are not valid. Previously,
preconditions were implicitly checked via the conditional structure
("is this sparse? is this nested?") that is tested before we start
reading attributes. This structure has to be replicated in
describe_tensor, and I have almost assuredly gotten it wrong on my
first try (I'll be grinding through it on CI; a careful audit will
help too, by auditing that I've tested all the same conditionals that
the original access was guarded by.)
- I originally submitted https://github.com/pytorch/pytorch/pull/121821
for the symbolic shapes change, but it turned out the way I did it
there didn't actually work so well for this PR. I ended up just
inlining the symbolic shapes allocation logic into MetaConverter
(look for calls to maybe_specialize_sym_int_with_hint), maybe there
is a better way to structure it, but what I really want is to
just read sizes/strides/offset directly off of MetaTensorDesc; I
don't want another intermediate data structure.
- Some fields aren't serializable. These are documented as "NOT
serializable". ctx/type should morally be serializable and I just
need to set up a contract with subclasses to let them be serialized.
The fake_mode is used solely to test if we are refakefying with
a pre-existing ShapeEnv and we want to reuse the SymInt
directly--serializing this case is hopeless but I am kind of hoping
after this refactor we do not need this at all. view_func is not
serializable because it's a bound C implemented method. Joel has
promised me that this is not too difficult to actually expose as a
true data structure, but this is the edgiest of edge cases and there
is no reason to deal with it right now.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122044
Approved by: https://github.com/eellison
ghstack dependencies: #122018
Summary:
With a simple bench in the TestDeserializer.test_basic function:
```
time_start = time.time()
for i in range(1000):
    self.check_graph(MyModule(), inputs)
warnings.warn(f"time_taken: {time.time() - time_start}")
```
and forcing FakeTensorConfig.debug to True, record_stack_traces to True, and logging level to debug, it shows that the changed code is consistently around 20 secs faster (~90s vs. originally ~110s)
Test Plan:
test passed, see summary
compared debug trace before and after:
- exactly the same for fake tensor and proxy callsite https://www.internalfb.com/intern/diffing/?paste_number=1189883685
- slightly different for the user frame in proxy node https://www.internalfb.com/intern/diffing/?paste_number=1189884347
Differential Revision: D54237017
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121449
Approved by: https://github.com/angelayi
This does not introduce a new test but is tested by checking that all the classes we already have still behave as before now that they don't explicitly disable torch_function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
`Tensor.__repr__` calls functions which can perform logging, which ends up logging `self` (via `__repr__`), causing an infinite loop. Instead of logging all the args in FakeTensor.dispatch, log the actual parameters (and use `id` to log the tensor itself).
The change to torch/testing/_internal/common_utils.py came up during testing - in some ways of running the test, `parts` was `('test', 'test_testing.py')`, so `i` was 0 and we were doing a join on `()`, which was causing an error.
Repro:
```
import torch
from torch.testing import make_tensor
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
t = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
t2 = FakeTensor.from_tensor(t, FakeTensorMode())
print(repr(t2))
```
and run with `TORCH_LOGS=+all`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120206
Approved by: https://github.com/yanboliang, https://github.com/pearu
Summary: an fbcode test exposed a shortcoming where we could serve a FakeTensor from the cache with the wrong inference_mode. Take the current mode into account in the cache key so we only serve entries created in the same mode we're currently in.
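A hedged sketch of the fix (hypothetical key builder, not the actual cache code): fold the current inference_mode state into the cache key so an entry created under inference_mode is never served outside it, and vice versa.
```
import torch

def make_cache_key(op, args):
    return (op, args, torch.is_inference_mode_enabled())
```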
Test Plan: New unit test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119963
Approved by: https://github.com/eellison