Make it so that it is valid to set metadata after detach calls, like `x.detach().resize_(...)`.
This technically lifts some restrictions around `.data`. This PR means that you can now technically call `x.data.resize_(...)`, which can now directly resize `x` instead of erroring.
My understanding: Before the tensor-variable merge, when `x` and `x.data` were really different tensors, you could resize `x.data` independently of `x`, and during the merge, this error was added to avoid silent confusing behavior changes.
It was agreed that this error has been around long enough (several years) that it's acceptable to drop. cc @albanD @ezyang.
(Ed already had a prototype PR [here](https://github.com/pytorch/pytorch/pull/83545) - I ended up making one to try to slog through test failures).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83590
Approved by: https://github.com/ezyang
I noticed I was missing tensor creations with modes when I tried
to delete proxy tensor. This was the cause.
Hypothetically, all PyInterpreter calls could get this treatment.
But I think it only matters for detach; the rest do not return
Tensors and most modes will not be interested in them.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83372
Approved by: https://github.com/zou3519
Fixes#81774
`TensorOptions` arguments in the JIT schema are optional, but in the Python API these were being translated to non-optional but with a default value. This change makes the arguments accept `None` for consistency with the JIT schema. However, it also means that `dtype=c10::nullopt` was previously completely untested so this also fixes several related bugs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82241
Approved by: https://github.com/ngimel
Currently we have 2 ways of doing the same thing for torch dispatch and function modes:
`with push_torch_dispatch_mode(X)` or `with X.push(...)`
is now the equivalent of doing
`with X()`
This removes the first API (which is older and private so we don't need to go through a deprecation cycle)
There is some risk here that this might land race with a PR that uses the old API but in general it seems like most are using the `with X()` API or `enable_torch_dispatch_mode(X())` which isn't getting removed.
EDIT: left the `with X.push(...)` API since there were ~3 land races with that over the past day or so. But made it give a warning and ask users to use the other API
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78215
Approved by: https://github.com/ezyang
This PR is doing a few interrelated things, all of which are necessary to get correctness. Read the comment in torch/fx/experimental/proxy_tensor.py for the high level overview.
Let's break down the parts of this PR:
* Bug fix where `enable_torch_dispatch_mode` with `None` doesn't work. This make `enable_torch_dispatch_mode(current_mode.inner)` work which is the basis for how we temporarily disable fake tensor mode.
* Bug fix for when fake tensor mode is combined with a non-mode tensor subclass. This actually could be ablated from this PR but it affects where the logic for allowing non fake tensor inputs with lift goes, so it's all in here in one go. There are some relevant tests for the fix in fake tensor, but it turns out I didn't need this because I'm always using proxy tensors as a mode (which ensures the ordering is right.)
* New `lift_fresh` view operator. Note that like lift, we have to manually write the functionalize kernel for these functions.
* The actual change, which is to save constants when we see them in the proxy tensor mode, and then propagate them as we go (because otherwise you'll handle mutations on constants incorrectly--see test.)
This is mildly BC-breaking if anyone was previously interposing on
at::lift, but this operator was relatively new and I checked
functorch which has no explicit reference to lift. So I think it
should not be too disruptive.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81192
Approved by: https://github.com/samdow, https://github.com/bdhirsh
I noticed that in some situations torch dispatch modes were being
invoked with a mode active, which isn't supposed to happen (we
disable modes before calling into the user mode.) I also noticed that
I was getting a warning that I had a deprecated non-static definition of
torch dispatch on an argument even though there wasn't any.
It turns out this is because modes were part of the overloaded arguments
list in the Python fallback kernel for torch dispatch. This is wrong;
instead we should rely on the actual dispatching function to consult
modes. This makes the code simpler.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80992
Approved by: https://github.com/zou3519
We don't have any coverage for meta tensor correctness for backwards
because torch function mode can only allow us to interpose on
Python torch API calls, but backwards invocations happen from C++.
To make this possible, I add torch_dispatch_meta test which runs the
tests with __torch_dispatch__
While doing this, I needed to generate fresh expected failure / skip
lists for the new test suite, and I discovered that my original
scaffolding for this purpose was woefully insufficient. So I rewrote
how the test framework worked, and at the same time rewrote the
__torch_function__ code to also use the new logic. Here's whats
new:
- Expected failure / skip is now done on a per function call basis,
rather than the entire test. This means that separate OpInfo
samples for a function don't affect each other.
- There are now only two lists: expect failure list (where the test
consistently fails on all runs) and skip list (where the test
sometimes passes and fails.
- We explicitly notate the dtype that failed. I considered detecting
when something failed on all dtypes, but this was complicated and
listing everything out seemed to be nice and simple. To keep the
dtypes short, I introduce a shorthand notation for dtypes.
- Conversion to meta tensors is factored into its own class
MetaConverter
- To regenerate the expected failure / skip lists, just run with
PYTORCH_COLLECT_EXPECT and filter on a specific test type
(test_meta or test_dispatch_meta) for whichever you want to update.
Other misc fixes:
- Fix max_pool1d to work with BFloat16 in all circumstances, by making
it dispatch and then fixing a minor compile error (constexpr doesn't
work with BFloat16)
- Add resolve_name for turning random torch API functions into string
names
- Add push classmethod to the Mode classes, so that you can more easily
push a mode onto the mode stack
- Add some more skips for missing LAPACK
- Added an API to let you query if there's already a registration for
a function, added a test to check that we register_meta for all
decompositions (except detach, that decomp is wrong lol), and then
update all the necessary sites to make the test pass.
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77477
Approved by: https://github.com/zou3519
Double-header bug fix:
- As reported by jansel, dtypes are still showing up as integers
when the schema is an optional dtype. This is simple enough to
fix and I added a test for it. But while I was at it...
- I noticed that the THPMemoryFormat_new idiom with "unused" name
doesn't actually work, the repr of the returned memory format
object is wrong and this shows up when we try to log the args/kwargs.
So I fixed memory format to do it properly along with everything
else.
Fixes https://github.com/pytorch/pytorch/issues/77135
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77543
Approved by: https://github.com/albanD, https://github.com/jansel
You can now do a lot of crazy things about redefining the behavior of an operator, and still be fast in cuda !!!
Example 1: swapping where's branches
```
code_string = "template <typename T> T inverted_where(bool cond, T a, T b){ return !cond ? a : b; }"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::where.self', jitted_fn, "CUDA")
# torch.where is now overridden
```
Example 2: approximate gelu with relu
```
code_string = "template <typename T> T fast_gelu(T a){ return a > 0 ? a : 0;}"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::gelu', jitted_fn, "CUDA")
# torch.nn.GELU and torch.nn.function.gelu are now overridden
```
Example 3: clipping output for numerical unstable kernels
```
code_string = "template <typename T> T clipped_exp(T a){ return a > T(10.0) ? T(22026.4657948) : exp(a); }"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::exp', jitted_fn, "CUDA")
# torch.exp(x) and x.exp() are now overridden
```
Example 4: Simulate buggy hardware behaviors
```
code_string = "template <typename T> T buggy_add(T a, T b){ return a + b + T(1); }"
jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::add.Tensor', jitted_fn, "CUDA")
torch.add(x, y), "x + y" and x.add(y) are now overridden
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77121
Approved by: https://github.com/anjali411
For the most part, PrimTorch refs have the same signature as their
ATen equivalents. I modify most PrimTorch refs to register themselves
as decompositions, using the prim name they wrap to find the aten name
(except for a few cases where the prim/aten names mismatch). There are
some exclusions, falling into one of two categories:
- The torch equivalent was already implemented as a CompositeImplicitAutograd
decomposition in C++
- The ref doesn't support enough features (e.g., the real deal has more
kwargs / overloads than are currently implemented)
PrimTorch refs are written as a single function that supports all
overloads, and this style is convenient for cases where we have a bundle
of overloads for what morally is a single overload with a Union type
on an argument (which we ought to have supported in
native_functions.yaml but blah); to support registering a single decomp
for all the overloads, we modify register_decomposition to register
to ALL overloads if you pass it an overload packet. This is technically
BC breaking but no tests started failing because of it.
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76835
Approved by: https://github.com/Chillee, https://github.com/mruberry