A longstanding confusion in the implementation of fake tensor and proxy tensor is what to do about torch.ops.aten.sym_sizes and related calls. In particular, when you have a tensor that (1) has symbolic shapes and (2) has a `__torch_dispatch__` call, previously, you would always get `__torch_dispatch__` calls for sizes/strides query, *even if you didn't request it* via the dispatch kwargs in `make_wrapper_subclass`.
The reason for this is because we were previously mixing several concepts: "I want to dispatch to Python", "I want to call a virtual method" and "I have dynamic shapes". A single boolean variable controlled all of these things, and so it was not possible to understand inside TensorImpl what the user had actually originally requested.
In this PR, we track each of these concepts individually so that we can preserve user intent. Then, we combine these into a single "policy" variable that controls whether or not we can use the fastpath or not. For the policy to trigger, we only need one of the exceptional cases to be true.
Billing of changes:
* Rename `set_sizes_strides_policy` to `set_custom_sizes_strides`; in general, you cannot DIRECTLY set policy; you have to indirectly set it by the public functions.
* Some helpers for sizes and strides, since it's more complicated (as it is an enum, rather than just bools as is the case for device and layout). `matches_python_custom` is used to test the Python dispatch user ask. `matches_policy` does the policy test (only used in the user facing functions.)
* I reorged the accessor methods so that they are more logical. This makes the diff bad, so I recommend reading the final code directly.
* The default custom implementations now more reliably call their default() implementations
* As bonus refactor, I devirtualized some functions that don't need to be virtual
* `set_sym_sizes_and_strides` is renamed to `set_sizes_and_strides` to make it easier to use in template contexts; it optionally takes a storage offset now so you can set all three values at the same time. If you use the SymInt overload but there are no symbolic integers, we give you a normal resize.
* This adds `sym_storage_offset` since we had that in the symbolic shapes branch and there's no reason not to put it in (and it reduces merge conflicts)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84641
Approved by: https://github.com/wconstab
I realized that we can deal with the dead vtable problem by...
introducing another indirection! The resulting code is worse
(you have to do one more dereference to get to the vtable), but
the reduction in boilerplate is, IMO, worth it.
I did this refactor because I'm about to add a lot more methods
to PyInterpreter to handle expunging SymInt from TensorImpl.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84388
Approved by: https://github.com/albanD
Add `TensorImpl::sym_strides`, bind it to python with `torch.ops.aten.sym_strides`, and use it in `ProxyTensor` and `FakeTensor`.
Before, `ProxyTensor` was generating `ProxySymInt`'s for the sizes, but not for the strides. Internally we still represent strides with a `SymIntArrayRef` though, so I ran into some weird issues where sizes were showing up as `ProxySymInt`, but strides were `PySymInt`'s.
Differential Revision: [D38594558](https://our.internmc.facebook.com/intern/diff/D38594558)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81300
Approved by: https://github.com/ezyang
This PR relands sym_numel #82374 and fixes the ios build break in this commit : 8cbd0031c5
which was a type mismatch in an equality.
### Description
<!-- What did you change and why was it needed? -->
### Issue
<!-- Link to Issue ticket or RFP -->
### Testing
<!-- How did you test your change? -->
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82731
Approved by: https://github.com/malfet
I noticed that in some situations torch dispatch modes were being
invoked with a mode active, which isn't supposed to happen (we
disable modes before calling into the user mode.) I also noticed that
I was getting a warning that I had a deprecated non-static definition of
torch dispatch on an argument even though there wasn't any.
It turns out this is because modes were part of the overloaded arguments
list in the Python fallback kernel for torch dispatch. This is wrong;
instead we should rely on the actual dispatching function to consult
modes. This makes the code simpler.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80992
Approved by: https://github.com/zou3519
The pattern of a PyObject* bundled with a PyInterpreter* is pretty
useful in many contexts (e.g., TorchDispatchTypeObject) so I have turned
it into a dedicated class SafePyObject. In the process I fixed a
bug with the old TorchDispatchTypeObject (copy constructor/assignment
was not deleted), made the API more safe (retrieving the PyObject*
pointer requires verification that the PyInterpreter* matches) and
fixed some minor inefficiencies in C++ code.
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75142
Approved by: https://github.com/zou3519