Previously, we introduced new SymInt overloads for every function we wanted. This led to a lot of boilerplate, and also a lot of confusion about how the overloads needed to be implemented.
This PR takes a simpler but more risky approach: just take the original function and changes its ints to SymInts.
This is BC-breaking in the following ways:
* The C++ API for registering implementations for aten operators will change from int64_t to SymInt whenever you make this change. Code generated registrations in PyTorch do not change as codegen handles the translation automatically, but manual registrations will need to follow the change. Typically, if you now accept a SymInt where you previously only took int64_t, you have to convert it back manually. This will definitely break XLA, see companion PR https://github.com/pytorch/xla/pull/3914 Note that not all dispatch keys get the automatic translation; all the composite keys and Meta keys are modified to take SymInt directly (because they should handle them directly), and so there are adjustments for this.
This is not BC-breaking in the following ways:
* The user facing C++ API remains compatible. Even if a function changes from int to SymInt, the default C++ binding still takes only ints. (e.g., at::empty(IntArrayRef, ...). To call with SymInts, you must call at::empty_symint instead. This involved adding two more signatures to CppSignatureGroup; in many cases I refactored code to iterate over all signatures in the group instead of hard-coding the two that previously existed.
* This is TorchScript compatible; internally we treat SymInts as ints so there is no change to what happens at runtime in TorchScript. In particular, it's OK to reference an empty schema by its old type (using int types), as long as you're not doing string equality (which you shouldn't be), these parse to the same underyling type.
Structure of the PR:
* The general strategy of this PR is that, even when you write `SymInt` inside `native_functions.yaml`, sometimes, we will treat it *as if* it were an `int`. This idea pervades the codegen changes, where we have a translation from SymInt to c10::SymInt or int64_t, and this is controlled by a symint kwarg which I added and then audited all call sites to decide which I wanted. Here are some of the major places where we pick one or the other:
* The C++ FunctionSchema representation represents `SymInt` as `int`. There are a few places we do need to know that we actually have a SymInt and we consult `real_type()` to get the real type in this case. In particular:
* When we do schema validation of C++ operator registration, we must compare against true schema (as the C++ API will provide `c10::SymInt`, and this will only be accepted if the schema is `SymInt`. This is handled with cloneWithRealTypes before we check for schema differences.
* In `toIValue` argument parsing, we parse against the true schema value. For backwards compatibility reasons, I do still accept ints in many places where Layout/SymInt/etc were expected. (Well, accepting int where SymInt is expected is not BC, it's just the right logic!)
* In particular, because SymInt never shows up as type() in FunctionSchema, this means that we no longer need a dedicated Tag::SymInt. This is good, because SymInts never show up in mobile anyway.
* Changes to functorch/aten are mostly about tracking changes to the C++ API registration convention. Additionally, since SymInt overloads no longer exist, registrations for SymInt implementations are deleted. In many cases, the old implementations did not properly support SymInts; I did not add any new functionality with this PR, but I did try to annotate with TODOs where this is work to do. Finally, because the signature of `native::` API changed from int to SymInt, I need to find alternative APIs for people who were directly calling these functions to call. Typically, I insert a new dispatch call when perf doesn't matter, or use `at::compositeexplicitautograd` namespace to handle other caes.
* The change to `make_boxed_from_unboxed_functor.h` is so that we accept a plain IntList IValue anywhere a SymIntList is expected; these are read-only arguments so covariant typing is OK.
* I change how unboxing logic works slightly. Previously, we interpret the C++ type for Layout/etc directly as IntType JIT type, which works well because the incoming IValue is tagged as an integer. Now, we interpret the C++ type for Layout as its true type, e.g., LayoutType (change to `jit_type.h`), but then we accept an int IValue for it anyway. This makes it symmetric with SymInt, where we interpret the C++ type as SymIntType, and then accept SymInt and int IValues for it.
* I renamed the `empty.names` overload to `empty_names` to make it less confusing (I kept mixing it up with the real empty overload)
* I deleted the `empty.SymInt` overload, which ended up killing a pile of functions. (This was originally a separate PR but the profiler expect test was giving me grief so I folded it in.)
* I deleted the LazyDynamicOpsTest tests. These were failing after these changes, and I couldn't figure out why they used to be passing: they make use of `narrow_copy` which didn't actually support SymInts; they were immediately converted to ints.
* I bashed LTC into working. The patches made here are not the end of the story. The big problem is that SymInt translates into Value, but what if you have a list of SymInt? This cannot be conveniently represented in the IR today, since variadic Values are not supported. To work around this, I translate SymInt[] into plain int[] (this is fine for tests because LTC dynamic shapes never actually worked); but this will need to be fixed for proper LTC SymInt support. The LTC codegen also looked somewhat questionable; I added comments based on my code reading.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83628
Approved by: https://github.com/albanD, https://github.com/bdhirsh
- nondeterministic_seeded was not applied to enough functions. I added
some heuristics to codegen for identifying functions that are likely
to be random and added a bunch of these tags to functions. Not sure
I got all of them.
- Don't constant propagate through nondeterministic functions in FX
tracing.
It would be better to do some testing for the tag but this would be quite an effort.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83650
Approved by: https://github.com/bdhirsh, https://github.com/eellison
I was working on https://github.com/pytorch/torchdynamo/issues/80 and my
working hypothesis for what was causing the error was that proxy tensor
was not advertising correct dispatch keys, causing AMP to operate
differently when you traced. I could have fixed this directly by
replicating fake tensor's fix for setting dispatch keys to also apply to
proxy tensor, but I was like, "Why must I repeat myself."
This PR is the result. It completely deletes the ProxyTensor wrapper
subclass, so that when we are tracing, the tensors flowing through the
program are the *original* real or fake tensors, depending on what the
user requested in the top-level API. There is no more wrapping. To
store the Proxy objects necessary for actually doing tracing, I store
the property directly on the tensors. (Note: I never
clean up old entries from the map at the moment, this is easily fixed
by using a weak map)
Benefits of doing this:
* No more tip-toeing around no_dispatch() creation of new ProxyTensors;
we never create new tensors (except when we call the underlying func),
so you don't have to worry about accidentally tracing them.
* No more syncing up metadata from in place operators. In particular
https://github.com/pytorch/pytorch/issues/81526 is mooted
* This fixes https://github.com/pytorch/torchdynamo/issues/519 as we no longer need to teach proxy tensor to support sparse tensor.
* No more schlepping symbolic integers from the inner fake tensor to the
outer proxy tensor. If you can make a fake tensor with symbolic ints,
you're done, nothing else to do.
To avoid having to rewrite all of the guts, when I get to the actual
proxy tensor handler, I first "fetch" the stored ProxyTensor data from
the weakmap via a tree_map, and then operate on the consequent data as
before. A more optimized implementation is possible.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83330
Approved by: https://github.com/Chillee
We're on our way to deleting ProxyTensor entirely (see https://github.com/pytorch/pytorch/pull/83330 ), but before we can do that, we have to delete ProxySymInt first. Here's the plan.
Changes in torch.fx.experimental.symbolic_shapes
* The general idea is to do mode based tracing. This means we need a mode that can interpose on all SymInt operations. There are a few ways to do this, but I've done it the easy way: (1) I have a separate mode for SymInt operations specifically called SymDispatchMode, and (2) this mode operates on PySymInt (and not the basic SymInt which is user visible). I elided Int from the name because if we add SymFloats I want to use the same mode to handle those as well, and I used Dispatch rather than Function because this is the "inner" dispatch operating PySymInt and not SymInt (this is not a perfect analogy, but SymFunctionMode definitely seemed wrong as you still must go through the C++ binding.) The mode is entirely implemented in Python for ease of implementation. We could have implemented this more symmetrically to TorchFunctionMode in C++, but I leave that as later work; this API is unlikely to get used by others (unlike TorchFunctionMode). One downside to not doing the mode in C++ is that we still have to do the hop via a preexisting PySymInt to wrap; this is currently not a big deal as conversion to SymInts only really happens when there is already another SymInt floating around. SymDispatchMode is pared down from TorchDispatchMode; there is no ancestor tracking since I don't expect people to be mixing up SymDispatchModes.
* I made some improvements for tracing. When I invoke the SymDispatchMode handler, I would like constants to show up as constants, so they can be directly inlined into the FX graph (rather than going through a wrapping process first, and then the wrapped SymInt being used in the operation). To do this, I directly track if a PySymInt is a constant at construction time. Only wrapped PySymInts are constants.
* For convenience, PySymInts now support all magic methods that regular SymInts do. This is so that redispatch inside the SymDispatchMode can be written the idiomatic way `func(*args, **kwargs)` where func is an operator. The original names are retained for direct C++ calls.
Changes in torch.fx.experimental.proxy_tensor
* OK, so we got a new SymDispatchMode, so we define a ProxySymDispatchMode and activate it when we start tracing. This mode is currently unconditionally activated although technically we only need to activate it when doing symbolic tracing (it doesn't matter either way as there are no SymInts if you are not doing symbolic tracing).
* We delete ProxySymInt. To do this, we must now record the proxy for the SymInt some other way. Based on discussion with Chillee, it is more intuitive to him if the proxies are still recorded on the SymInt in some way. So we store them in the `__dict__` of the PySymInt, indexed by Tracer. An improvement is to make this a weak map, so that we remove all of these entries when the tracer dies. In an original version of this PR, I keyed on the mode itself, but tracer is better as it is accessible from both modes (and as you will see, we will need to fetch the map from both the ProxySymDispatchMode as well as the ProxyTorchDispatchMode.) The implementation of SymDispatchMode now simply retrieves the proxies, performs the underlying operation as well as the FX graph recording, and then records the output proxy to the PySymInt. Note that FX tracing does not work with proxies and SymInts, so we manually call `call_function` to ensure that the correct operations get recorded to the graph. This means conventional FX retracing with proxies only will not work with these graphs, but there wasn't really any reason to do this (as opposed to `make_fx` retracing) anyway. Constants are detected and converted directly into Python integers.
* SymInts can show up as arguments to tensor operations, so they must be accounted for in ProxyTorchDispatchMode as well. This is done by searching for SymInt arguments and converting them into proxies before the proxy call. This can be done more efficiently in a single `tree_map` but I'm lazy. The helper `unwrap_symint_proxy` conveniently implements the unwrapping in one place given a tracer; unfortunately it cannot be shared with SymDispatchMode as SymDispatchMode gets PySymInts, but ProxyTensorMode gets SymInts. Similarly, tensors that are returned from tensor operations can have SymInts in their shapes, which need fresh proxies allocated. To avoid leaking internal details of SymInt shape computation to the tensor operation graph, these SymInts are always given proxies derived from `x.size(dim)` call on their return tensor. We also need to do this for strides and numel but have not done so yet. Furthermore, we must avoid tracing internal SymInt calls while we run meta operations on the true operation; this is achieved by also disabling SymInt tracing on the inside of tensor tracing. This is analogous to how tensor tracing is disabled inside the implementation of tracing mode, but unfortunately we are unable to use the same mechanism (this would have been easier if the two modes could be combined somehow, and I am amenable to suggestions to try harder to achieve this.)
* Because there are no more ProxySymInts, we no longer need to do anything to unwrap SymInt. Furthermore, we do not need to reallocate ProxySymInts on class creation.
* If a bare SymInt without a Proxy is encountered, it is assumed that this must be a constant. `create_arg` handles this case. Non-constant free SymInts result in an assert error.
* The initial input handling in `dispatch_trace` involves traversing all of the input tensors, traversing over their shapes, and assigning proxies for the SymInts in shapes in the same way we handle proxies for the output tensors.
The preexisting testing is inadequate but will be better after I rebase past https://github.com/pytorch/pytorch/pull/82209
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83380
Approved by: https://github.com/samdow
This should fix a few of the errors I was seeing when I turned on functionalization in torchbench. It also fixes this AOTAutograd repro with resnet18:
```
import torch
from torchvision.models import resnet18
from functorch._src.compilers import nop
from functorch._src.aot_autograd import aot_module
from functorch.compile import config
config.use_functionalize = True
model = resnet18().cuda().half().to(memory_format=torch.channels_last)
input = torch.randn(256, 3, 224, 224, device='cuda', dtype=torch.float16) \
.to(memory_format=torch.channels_last).detach().requires_grad_(True)
input_expected = input.clone().detach().requires_grad_(True)
fn = aot_module(model, nop)
out = fn(input)
out_expected = model(input_expected)
print(torch.allclose(out, out_expected))
out.sum().backward()
out_expected.sum().backward()
print(torch.allclose(input.grad, input_expected.grad))
```
The problem was that functorch adds a decomp to the decomp table for `new_zeros`:
```
@register_decomposition(aten.new_zeros, aot_autograd_decompositions)
def new_zeros(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
return torch.zeros(size, dtype=inp.dtype, device=inp.device)
```
When calling that decomp from inside of `ProxyTensorDispatchMode`, the ProxyTensorMode is already disabled, and `torch.zeros` doesn't take in any tensor-like arguments, so we never end up dispatching back into python again.
The way that manifests is that the output of `new_zeros()` gets baked as a constant into the AOTAutograd FX graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83207
Approved by: https://github.com/ezyang
This allows you to directly call into the CompositeImplicitAutograd
implementation of an operator, *without* changing any aspects of the
dispatcher state. In particular, you can use this to recursively call
into a decomposition, dispatching back to your tensor subclass/mode
as desired.
Hypothetically, we should also make these available in the
decompositions dictionary, but I'm leaving this as future work as
enumerating these decompositions is annoying (as operators are lazily
registered.)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83075
Approved by: https://github.com/albanD
Reference: #69991
This adds a new CompositeExplicit Operator `at::assert_all_true` (also exposed in Python) to check the truthiness of a tensor and throw an error based on that.
This helps us mitigate `TORCH_CHECK(t.all().item<bool>(), "err_msg")` pattern which is not composite compliant.
Using the mentioned operator, we fix `quantile` and `nanquantile` to be Composite Compliant.
Question: Should it be documented?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81767
Approved by: https://github.com/zou3519
This gives us tests for fake (all passing) and symbolic (not all passing).
I needed to add a gadget for xfail'ing tests in symbolic. It might be
generally useful, let me know what you think.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82746
Approved by: https://github.com/eellison
I intend to run all of the proxy tensor tests on each of our tracing
modes, but to do this I have to make the tests parametrized on
tracing mode first. This does that refactor, without adding any
new tests.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82739
Approved by: https://github.com/eellison
This PR relands sym_numel #82374 and fixes the ios build break in this commit : 8cbd0031c5
which was a type mismatch in an equality.
### Description
<!-- What did you change and why was it needed? -->
### Issue
<!-- Link to Issue ticket or RFP -->
### Testing
<!-- How did you test your change? -->
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82731
Approved by: https://github.com/malfet
Add support for sparse fake tensors.
- The testing strategy is to run a fake tensor cross ref test on `test_sparse.py`. This is necessary because OpInfo sparse coverage is completely nonexistent. We could have tried to turn on cross ref testing globally for all files, but that would be very time consuming and the tests I'm interested in are mostly in this file. There are some exclusions in testing for things that don't work.
- I make fake tensor converter raise a UnsupportedFakeTensorException if the meta converter fails to do a conversion (which can happen in a relatively large number of situations).
- I relax fake tensor invariants so that you can make a fake tensor from a meta tensor. This is useful because in the cross ref test sometimes we operate on meta tensors.
- Fake tensor wrapping is improved to handle the case when a function doesn't return any tensors
- Meta converter is taught how to convert sparse tensors to meta
There's still a little more cleanup that needs to be done, but this is good for review.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82172
Approved by: https://github.com/eellison
Adds a new context manager `TorchRefsNvfuserCapabilityMode` for conditional rewrite of `torch.*` calls to `torch._refs.*` based on whether the decomposition consisting of prims supports nvFuser execution or not.
A new optional argument for `TorchRefsMode` is added - `should_fallback_fn`, a callable that returns whether the original `torch.foo` or the replacement `torch._refs.foo` should be used.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81764
Approved by: https://github.com/ezyang
This makes symbolic tracing tests for logsigmoid and xlogy start working again.
While I'm at it, add pin_memory and layout kwargs to empty; but they
don't actually do anything and raise an error if they are non standard.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82332
Approved by: https://github.com/eellison
Once CompositeImplicitAutograd gets registered to Python key, this will
ensure that tensor subclasses can interpose on these functions directly
rather than getting decomposed. We prefer not decomposing as these
functions are functional, but their implementations use inplace
operations (and are thus more difficult to deal with, unless you use
functionalization.)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82238
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
This ref does more things than `torch.norm`, and it fixes a few bugs
that `torch.norm` has. This implementation and the `torch.norm`
implementation come to terms in the next PR of this stack
We put this PR before, as otherwise `test_decomp.py` was failing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81765
Approved by: https://github.com/ngimel
This PR is doing a few interrelated things, all of which are necessary to get correctness. Read the comment in torch/fx/experimental/proxy_tensor.py for the high level overview.
Let's break down the parts of this PR:
* Bug fix where `enable_torch_dispatch_mode` with `None` doesn't work. This make `enable_torch_dispatch_mode(current_mode.inner)` work which is the basis for how we temporarily disable fake tensor mode.
* Bug fix for when fake tensor mode is combined with a non-mode tensor subclass. This actually could be ablated from this PR but it affects where the logic for allowing non fake tensor inputs with lift goes, so it's all in here in one go. There are some relevant tests for the fix in fake tensor, but it turns out I didn't need this because I'm always using proxy tensors as a mode (which ensures the ordering is right.)
* New `lift_fresh` view operator. Note that like lift, we have to manually write the functionalize kernel for these functions.
* The actual change, which is to save constants when we see them in the proxy tensor mode, and then propagate them as we go (because otherwise you'll handle mutations on constants incorrectly--see test.)
This is mildly BC-breaking if anyone was previously interposing on
at::lift, but this operator was relatively new and I checked
functorch which has no explicit reference to lift. So I think it
should not be too disruptive.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81192
Approved by: https://github.com/samdow, https://github.com/bdhirsh
```
def f():
val = torch.tensor(float('inf'))
return torch.full((100, 100), val)
```
today we turn `val` into a ProxyTensor, and then complain when we try to turn `val` into a scalar.
We call `aten::lift` when we call `torch.tensor(5)`, so this just prevents those from being turned into ProxyTensors unnecessarily.
cc: @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81024
Approved by: https://github.com/ezyang