This should fix a few of the errors I was seeing when I turned on functionalization in torchbench. It also fixes this AOTAutograd repro with resnet18:
```
import torch
from torchvision.models import resnet18
from functorch._src.compilers import nop
from functorch._src.aot_autograd import aot_module
from functorch.compile import config
config.use_functionalize = True
model = resnet18().cuda().half().to(memory_format=torch.channels_last)
input = torch.randn(256, 3, 224, 224, device='cuda', dtype=torch.float16) \
    .to(memory_format=torch.channels_last).detach().requires_grad_(True)
input_expected = input.clone().detach().requires_grad_(True)
fn = aot_module(model, nop)
out = fn(input)
out_expected = model(input_expected)
print(torch.allclose(out, out_expected))
out.sum().backward()
out_expected.sum().backward()
print(torch.allclose(input.grad, input_expected.grad))
```
The problem was that functorch adds a decomp to the decomp table for `new_zeros`:
```
@register_decomposition(aten.new_zeros, aot_autograd_decompositions)
def new_zeros(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
    return torch.zeros(size, dtype=inp.dtype, device=inp.device)
```
When that decomp is called from inside `ProxyTensorDispatchMode`, the ProxyTensorMode is already disabled, and `torch.zeros` doesn't take any tensor-like arguments, so we never end up dispatching back into Python.
The way that manifests is that the output of `new_zeros()` gets baked into the AOTAutograd FX graph as a constant.
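To see the mechanism in isolation, here is a minimal sketch using today's `TorchDispatchMode` API (an assumption on my part; the mode-management surface has changed since the `enable_torch_dispatch_mode` era this PR targets):
```
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("intercepted:", func)
        # While we are inside __torch_dispatch__, this mode is disabled,
        # so a factory call with no tensor arguments never re-enters it:
        torch.zeros(2)
        return func(*args, **(kwargs or {}))

with LoggingMode():
    torch.ones(2)  # printed once; the inner torch.zeros is invisible
```
That invisible `torch.zeros` is exactly what produced the real tensor that got baked into the graph.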
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83207
Approved by: https://github.com/ezyang
This allows you to directly call into the CompositeImplicitAutograd
implementation of an operator, *without* changing any aspects of the
dispatcher state. In particular, you can use this to recursively call
into a decomposition, dispatching back to your tensor subclass/mode
as desired.
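For a sense of the usage, here is a sketch assuming this landed as the `OpOverload.decompose` method (which returns `NotImplemented` when no CompositeImplicitAutograd kernel is registered):
```
import torch

a, b = torch.randn(2, 3), torch.randn(3, 4)
# matmul is CompositeImplicitAutograd; this runs that kernel directly,
# without touching dispatcher state, so its constituent ops (mm, etc.)
# re-dispatch and can be intercepted by a subclass or mode.
out = torch.ops.aten.matmul.default.decompose(a, b)
print(torch.allclose(out, a @ b))  # True
```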
Hypothetically, we should also make these available in the
decompositions dictionary, but I'm leaving this as future work as
enumerating these decompositions is annoying (as operators are lazily
registered.)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83075
Approved by: https://github.com/albanD
Reference: #69991
This adds a new CompositeExplicitAutograd operator, `at::assert_all_true` (also exposed in Python), which checks that every element of a tensor is true and throws an error otherwise.
This helps us mitigate the `TORCH_CHECK(t.all().item<bool>(), "err_msg")` pattern, which is not composite compliant.
Using the new operator, we fix `quantile` and `nanquantile` to be composite compliant.
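For illustration, a sketch of the before/after shape of such a check (the Python spelling of the new operator is an assumption based on the name above):
```
import torch

def check_q(q):
    # Old pattern, not composite compliant: item() reads concrete data,
    # which wrapper subclasses (e.g. functorch's batched tensors) lack.
    if not bool(((q >= 0) & (q <= 1)).all().item()):
        raise RuntimeError("q must be in the range [0, 1]")

# New pattern (hypothetical Python binding mirroring at::assert_all_true):
# torch.ops.aten.assert_all_true((q >= 0) & (q <= 1), "q must be in [0, 1]")
```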
Question: Should it be documented?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81767
Approved by: https://github.com/zou3519
This gives us tests for fake (all passing) and symbolic (not all passing).
I needed to add a gadget for xfail'ing tests in symbolic mode. It might be
generally useful; let me know what you think.
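Roughly this shape, with hypothetical names (the PR's actual gadget may differ):
```
import unittest

# Hypothetical per-mode xfail list: tests named here are expected to
# fail only when run under the symbolic tracing mode.
symbolic_xfails = {"test_some_op_symbolic"}

def maybe_xfail(test_fn, tracing_mode):
    if tracing_mode == "symbolic" and test_fn.__name__ in symbolic_xfails:
        return unittest.expectedFailure(test_fn)
    return test_fn
```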
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82746
Approved by: https://github.com/eellison
I intend to run all of the proxy tensor tests on each of our tracing
modes, but to do this I have to make the tests parametrized on
tracing mode first. This does that refactor, without adding any
new tests.
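The shape of the refactor, sketched with PyTorch's internal parametrization helpers (an assumption; the PR may wire this differently):
```
from torch.testing._internal.common_utils import (
    TestCase, parametrize, instantiate_parametrized_tests, run_tests,
)

class TestProxyTensor(TestCase):
    @parametrize("tracing_mode", ["real", "fake", "symbolic"])
    def test_simple(self, tracing_mode):
        # the test body threads tracing_mode through to make_fx
        ...

instantiate_parametrized_tests(TestProxyTensor)

if __name__ == "__main__":
    run_tests()
```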
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82739
Approved by: https://github.com/eellison
This PR relands `sym_numel` (#82374) and fixes the iOS build break in commit 8cbd0031c5,
which was a type mismatch in an equality.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82731
Approved by: https://github.com/malfet
Add support for sparse fake tensors.
- The testing strategy is to run a fake tensor cross-ref test on `test_sparse.py`. This is necessary because OpInfo sparse coverage is completely nonexistent. We could have tried to turn on cross-ref testing globally for all files, but that would be very time consuming, and the tests I'm interested in are mostly in this file. There are some testing exclusions for things that don't work.
- I make the fake tensor converter raise an `UnsupportedFakeTensorException` if the meta converter fails to do a conversion (which can happen in a relatively large number of situations).
- I relax the fake tensor invariants so that you can make a fake tensor from a meta tensor. This is useful because in the cross-ref test we sometimes operate on meta tensors.
- Fake tensor wrapping is improved to handle the case where a function doesn't return any tensors.
- The meta converter is taught how to convert sparse tensors to meta.
There's still a little more cleanup that needs to be done, but this is good for review.
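As a feel for what now works, a minimal sketch with the current `FakeTensorMode` surface (assumed; the converter entry points have moved around over time):
```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

s = torch.sparse_coo_tensor(
    torch.tensor([[0, 1], [1, 0]]), torch.tensor([1.0, 2.0]), (2, 2)
)
mode = FakeTensorMode()
fake = mode.from_tensor(s)  # routes through the meta converter internally
print(fake.is_sparse, fake.shape)  # True torch.Size([2, 2])
```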
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82172
Approved by: https://github.com/eellison
Adds a new context manager, `TorchRefsNvfuserCapabilityMode`, for conditionally rewriting `torch.*` calls to `torch._refs.*`, based on whether the decomposition into prims supports nvFuser execution.
A new optional argument is added to `TorchRefsMode`: `should_fallback_fn`, a callable that returns whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.
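Usage is roughly the following (the import path is my assumption from this era of the codebase; the nvFuser integration has since been reorganized):
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.context import TorchRefsNvfuserCapabilityMode

def fn(x):
    return torch.sigmoid(x) + 1

# Inside the context manager, torch.sigmoid is rewritten to
# torch._refs.sigmoid only if its prim decomposition can run on nvFuser.
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(fn)(torch.randn(3))
```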
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81764
Approved by: https://github.com/ezyang
This makes the symbolic tracing tests for logsigmoid and xlogy start working again.
While I'm at it, I add pin_memory and layout kwargs to empty; they don't
actually do anything, and they raise an error if they are non-standard.
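In other words, the kwargs are accepted for signature compatibility only; a sketch of the shape of that guard (not the PR's code):
```
import torch

def empty_decomp(size, *, dtype=None, device=None,
                 layout=torch.strided, pin_memory=False):
    # Accept the kwargs so call sites type-check, but reject anything
    # non-default instead of silently ignoring it.
    assert layout in (None, torch.strided), "non-strided layout unsupported"
    assert not pin_memory, "pin_memory unsupported"
    return torch.empty(size, dtype=dtype, device=device)
```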
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82332
Approved by: https://github.com/eellison
Once CompositeImplicitAutograd gets registered to the Python key, this will
ensure that tensor subclasses can interpose on these functions directly
rather than getting their decompositions. We prefer not to decompose, as
these functions are functional while their implementations use in-place
operations (and are thus more difficult to deal with, unless you use
functionalization).
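A sketch of the observable difference under a logging mode (illustrative only; the specific ops this PR registers are not named above):
```
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogOps(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))

# Without a Python-key registration, a CompositeImplicitAutograd op is
# decomposed before this mode sees it, so the log shows its constituent
# ops; with the registration, the log shows the composite op itself.
with LogOps():
    torch.matmul(torch.randn(2, 3), torch.randn(3, 4))
```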
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82238
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
This ref does more than `torch.norm`, and it fixes a few bugs
that `torch.norm` has. This implementation and the `torch.norm`
implementation are brought into agreement in the next PR of this stack.
We put this PR first, as otherwise `test_decomp.py` was failing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81765
Approved by: https://github.com/ngimel
This PR is doing a few interrelated things, all of which are necessary for correctness. Read the comment in torch/fx/experimental/proxy_tensor.py for the high-level overview.
Let's break down the parts of this PR:
* Bug fix where `enable_torch_dispatch_mode` with `None` doesn't work. This makes `enable_torch_dispatch_mode(current_mode.inner)` work, which is the basis for how we temporarily disable fake tensor mode.
* Bug fix for when fake tensor mode is combined with a non-mode tensor subclass. This could actually be split out of this PR, but it affects where the logic for allowing non-fake tensor inputs with lift goes, so it's all in here in one go. There are some relevant tests for the fix in fake tensor, but it turns out I didn't need it, because I'm always using proxy tensors as a mode (which ensures the ordering is right).
* New `lift_fresh` view operator. Note that like lift, we have to manually write the functionalize kernel for these functions.
* The actual change, which is to save constants when we see them in proxy tensor mode, and then propagate them as we go (because otherwise you'll handle mutations on constants incorrectly; see the test, and the sketch below).
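A minimal sketch of the constant-mutation case (assuming `make_fx`; the test in the PR may use a different repro):
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    c = torch.tensor([1.0])  # a constant; lifted via lift_fresh under tracing
    c.add_(1)                # the mutation must be tracked, not baked in stale
    return x + c

gm = make_fx(f)(torch.randn(1))
print(gm.code)
```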
This is mildly BC-breaking if anyone was previously interposing on
at::lift, but this operator was relatively new, and I checked
functorch, which has no explicit reference to lift. So I think it
should not be too disruptive.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81192
Approved by: https://github.com/samdow, https://github.com/bdhirsh
```
def f():
    val = torch.tensor(float('inf'))
    return torch.full((100, 100), val)
```
Today we turn `val` into a ProxyTensor, and then complain when we try to turn `val` into a scalar.
We call `aten::lift` when we call `torch.tensor(5)`, so this change just prevents those freshly created constants from being turned into ProxyTensors unnecessarily.
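After the change, tracing the snippet above goes through, since `val` stays a plain constant tensor (a sketch assuming `make_fx`):
```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

gm = make_fx(f)()  # no longer fails when converting `val` to a scalar
print(gm.graph)
```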
cc: @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81024
Approved by: https://github.com/ezyang