Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64432
Original PR description + feedback here: https://github.com/pytorch/pytorch/pull/63048
I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.
**Table of Contents**
- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)
**Starting Points**
A good place to start when looking through the PR:
* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
* (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
* (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
* (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), that only care about mutation removal and not alias removal
* XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
* There is currently no {view}_copy replacement, because the pass just <replace views with copies> and <replace copies with views> steps have been combined. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.
**Main changes from the original PR**
(1) I use lambdas instead of a giant enum to handle all of the different views.
This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains a `forward` and `reverse` lambda, that knows how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`)
(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`.
This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design a (`FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.
(3) `FunctionalTensorWrapper` objects accurately report stride information.
It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking it's size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly. There's another annoying part to this: I'm pretty sure that we want to pass in the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try do the unwrapping.
To do this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the at::native::{view} op always redispatch straight to the CPU kernel (which will be another at::native:: kernel). This feels kind of heavy handed, but I'm not sure of a better way to do it.
(4) `FunctionalTensorWrapper` objects accurately report aliasing information.
There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage).
One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set.
Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?
(5) better docs :)
**View operator coverage**
(6) The functionalization pass now gets math-composite view ops for free.
I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation.
There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.
(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these {emoji:1f622}).
From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation
(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`.
These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.
The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()).
I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t to autograd, since the currently implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those change (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).
I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.
Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:
* the autograd implementations over those backwards functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functoinalization, we want them to (eventually call `{view}_copy` operators).
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalizations, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
* select
* slice
* diagonal
* as_stridied
* split
* split_with_sizes
A nice end state would probably be for the autograd + functoinalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.
**Current State + Next Steps**
There are a bunch of followups after this PR eventually lands. Roughly in order:
* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space in the by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.
**Example Codegen Output**
View Op:
```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {
auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
::std::vector<at::Tensor> out;
{
at::AutoDispatchBelowFunctionalize guard;
auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
// I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
// Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
}
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
[split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
return base.split(split_size, dim)[mutated_view_idx];
},
[split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
}
);
at::functionalization::impl::set_view_meta(out, self, view_meta);
at::AutoDispatchDirectlyToNative native_guard;
::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
at::functionalization::impl::set_strides(out, reference_tensor_output);
return out;
}
```
Mutation Op:
```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
at::functionalization::impl::sync(self);
at::functionalization::impl::sync(other);
auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
at::Tensor tmp_output;
{
at::AutoDispatchBelowFunctionalize guard;
// The functionalization pass explicitly doesn't pass out= parameters to the redispatch
tmp_output = at::redispatch::add(
ks & c10::after_func_keyset, self_, other_, alpha);
}
self.replace_(tmp_output);
at::functionalization::impl::maybe_add_update(self);
return self;
}
```
View + Mutation Op:
```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {
at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
[dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
return base.transpose(dim0, dim1);
},
[dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
}
);
at::functionalization::impl::mutate_view_meta(self, view_meta);
// See Note [Propagating strides in the functionalization pass]
// Directly update the sizes/strides/storage_offset fields on self using the inplace call.
// I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
// Its only job is to directly compute the output size/stride/storage_offset metadata.
at::AutoDispatchDirectlyToNative native_guard;
at::native::transpose_(self, dim0, dim1);
return self;
}
```
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D31942093
Pulled By: bdhirsh
fbshipit-source-id: b95598dae35dd1842fa8b1d8d1448332f3afaadf
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63094
This PR:
- Moves `FileManager` and its dependencies (`assert_never` and other imports) to `utils.py`, and updates all of the call-sites with the fresh imports
- Passes the list of NativeFunction objects into `gen_trace_type` directly, instead of requiring the function to regenerate it (we already have it)
The purpose of the reshuffling is to avoid circular dependencies in the next PR, where I add codegen for the functionalization pass, which gets called from `gen.py` (but depends on some stuff from the autograd codegen - in partulcar, the list of view ops).
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D31942096
Pulled By: bdhirsh
fbshipit-source-id: 36118facae61f25f8922bb43ad2818c80b53504e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61746
**Summary**
This commit introduces a new feature for structured kernels that allows
kernels to declare quantities as "precomputed" in
`native_functions.yaml`, compute them once in the `meta` function and
reuse them again in the `impl`. The names and types of these quantities
are used to generate code for a struct containing them that the `meta`
function must return. In the case of a handful of surveyed kernels
(`all,`, `any`, `avg_pool2d`), these quantities that are used both in
the `meta` and `impl` have the same meaning as certain kernel arguments
and in fact supersede them. Accordingly, the correspondence between a
kernel argument and the precomputed elements that supersede it is also
captured in `native_functions.yaml`. This information is used to unpack
the struct returned by `meta` and pass its contents correctly to the
`impl` function.
The primary goal is to avoid recompute and enhance developer experience
(e.g. sometimes people can forget to compute these elements while
porting a kernel).
Test Plan: Imported from OSS
Reviewed By: tugsbayasgalan
Differential Revision: D30407831
Pulled By: SplitInfinity
fbshipit-source-id: 00975525ea373721fe52d06f75cd4ac91f3dc556
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62188
These parts of the `set_output` code are identical for all operators in the
kernel registration files. So, this moves them from being copied into every
class to two helper functions at the top of the file.
Test Plan: Imported from OSS
Reviewed By: soulitzer
Differential Revision: D29962045
Pulled By: albanD
fbshipit-source-id: 753b8aac755f3c91b77ffa2c30a89ac91a84b7c4
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62187
This file can take 3 minutes on its own to compile, and after
python_functions.cpp is the second limiting factor for compile time of
`libtorch_python` on a 32-core threadripper. This splits it into 3 files that
take around 1 minute each to compile.
Test Plan: Imported from OSS
Reviewed By: H-Huang
Differential Revision: D29962048
Pulled By: albanD
fbshipit-source-id: 99016d75912bff483fe21b130cef43a6882f8c0e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63505
This isn't a public operator, just a helper function used in CUDA_tensor_apply.
Test Plan: Imported from OSS
Reviewed By: mruberry
Differential Revision: D30441305
Pulled By: ngimel
fbshipit-source-id: 84fabc701cbd8479e02d80f373a3dd62d70df2ce
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62278
In `Operators.h` we're using `str(BaseOperatorName)`, while in
`OperatorsEverything.cpp` we're using `str(OperatorName)`. e.g.
```
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::abs")
```
vs
```
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(abs_out, name, "aten::abs.out")
```
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D29962047
Pulled By: albanD
fbshipit-source-id: 5a05b898fc734a4751c2b0187e4eeea4efb0502b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62185
This file can take 5 minutes on its own to compile, and is the single limiting
factor for compile time of `libtorch_cpu` on a 32-core threadripper. Instead,
sharding into 5 files that take around 1 minute each cuts a full minute off the
overall build time.
This also factors out the `.findSchemaOrThrow(...).typed` step so the code can
be shared between `call` and `redispatch`.
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D29962049
Pulled By: albanD
fbshipit-source-id: be5df05fbea09ada0d825855f1618c25a11abbd8
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62184
File sharding is currently implemented twice, once for VariableType and once for
TraceType. This refactors the implementation into `FileManager` and also changes
it so template substitution is only done once and shared between the sharded
file and the "Everything" file.
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D29962050
Pulled By: albanD
fbshipit-source-id: 7858c3ca9f6e674ad036febd2d1a4ed2323a2861
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58065
This PR replaces the existing code-generated CPU fallback kernels that XLA uses with a single boxed CPU fallback.
Current state: there are a couple different design ideas that I want to point out, but the logic for the actually kernel is mostly done and passing tests.
### Design
To preface, I'm not 100% tied to the current design and I'm putting the PR up now for opinions and totally open to alternatives, some of which I listed below. Actually after writing this description, I'm leaning toward the following changes:
* Confirm whether or not we can remove all C++ logging info directly in the yaml.
**Current Design**
All of the CPU fallback codegen is deleted. In its place, XLA (and other external backends, later) can choose to opt into a CPU fallback by adding the following code in a C++ file. I have an corresponding [xla-side PR with the xla changes](https://github.com/pytorch/xla/pull/2945/files#diff-1a005c10039f0cb11130a3b740f5de716d2f10acaea121017016025861886798R1).
There's no actual requirement to split up the code into a .h and .cpp file, but that's necessary in the XLA case because they sometimes need to call the fallback directly from their handcrafted kernels.
```
// xla_cpu_fallback.h
#include <ATen/native/CPUFallback.h>
...
void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
...
```
```
// xla_cpu_fallback.cpp
#include "torch_xla/csrc/aten_cpu_fallback.h"
...
void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
// Do custom logging here
...
// Call the actual boxed CPU fallback.
at::native::cpu_fallback(op, stack);
}
TORCH_LIBRARY_IMPL(_, XLA, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&xla_cpu_fallback>());
}
```
Now that the fallback is exposed in the backend, they can call it directly. Doing so requires converting from an unboxed to a boxed context, which we provide a utility function before. E.g.:
```
#include <ATen/native/CPUFallback.h>
at::Tensor addmm(const at::Tensor& self,const at::Tensor& mat1,const at::Tensor& mat2,const at::Scalar& beta,const at::Scalar& alpha) {
....
if (...call_fallback...) {
return at::native::call_fallback_fn<&xla_cpu_fallback, decltype(at::addmm)>::call("aten::addmm", self, mat1, mat2, beta, alpha);
}
...
}
```
That `decltype(at::addmm)` logic isn't actually used everywhere in the xla-side PR yet, since you hit issues with overloads. I could use it everywhere once #58092 lands.
**Alternatives: The API for calling the CPU fallback directly is ugly, can we make it nicer?**
We could change the api to use `at::redispatch`, which would make it look something like this:
```
at::Tensor addmm(const at::Tensor& self,const at::Tensor& mat1,const at::Tensor& mat2,const at::Scalar& beta,const at::Scalar& alpha) {
....
if (...call_fallback...) {
return at::redispatch::addmm(c10::DispatchKeySet(c10::DispatchKey::CPUFallback), self, mat1, mat2, beta, alpha);
}
...
}
```
Which definitely feels cleaner, but also requires adding a new DispatchKey just for this use case. Conditionally calling the CPU fallback doesn't sound like a hugely important use case, so I don't know if giving up one of our 64 dispatch key slots is worth the API improvement. Totally open to other opinions though!
Another more mild improvement that would avoid having to pass operator string names (including overloads) around would be to codegen (yet another) namespaced API. Something like this:
```
at::Tensor addmm(const at::Tensor& self,const at::Tensor& mat1,const at::Tensor& mat2,const at::Scalar& beta,const at::Scalar& alpha) {
....
if (...call_fallback...) {
return at::fallback::addmm<&xla_cpu_fallback>(self, mat1, mat2, beta, alpha);
}
...
}
```
Writing that out actually I actually like it more (I think it'll let us get rid of `decltype(...)`). Maybe that is nice enough to warrant a new codegen API - I haven't tried adding that yet, but if people like it I'm happy to try it out.
**More alternatives**
The current design also involves the backend manually writing and registering the boxed fallback themselves, but an alternative would be for us to do it in codegen too: they would just need to pass in all of the C++ logging that they want done in the fallback, directly through the yaml. The main downsides:
* Backend code that wants to call the fallback needs to abide by whatever convention our codegen uses to name the generated boxed fallback.
* Passing custom C++ logging through yaml is just more fragile: right now xla uses an `iostream` to log each tensor arg in the operator, so we'd have to either force other backends into the same convention or figure something else out later.
To be fair, we actually already do that: XLA has custom per-tensor-arg logging for all of the generated `out` wrappers in the codegen, which we do by passing their C++ logging info through the yaml. This seems unnecessary though, since `out` wrappers just call into a functional kernel, which is hand written with its own custom logging. So my take is: try to remove custom C++ logging from the yaml, and if it turns out to be really necessary, then we may as well take advantage of that to codegen the fallback.
### Performance impact
While ops that fall back to CPU aren't exactly hot path, we probably don't want to use a boxed fallback if it turns out to be an absolute perf killer.
I ran my benchmarks using callgrind, benchmarking both `at::add` and `at::add_out` run on XLA. My callgrind benchmark for `at::add` can be found here (the add_out benchmark looks basically the same): https://www.internalfb.com/phabricator/paste/view/P415418587. I created the benchmark by hacking the existing xla C++ test build scripts and throwing in a reference to callgrind.
I also attached the full callgrind output for each benchmark; the full output is actually pretty noise and hard to parse, but I focused on everything underneath the `at::add()` call in the output, which was much more stable. My guess is that it's due to some heavyweight async startup processing that xla does.
`at::add`:
before: 88,505,130 instructions. Full output: https://www.internalfb.com/phabricator/paste/view/P415421001
after: 102,185,654 instructions. Full output: https://www.internalfb.com/phabricator/paste/view/P415421273
delta: ~15.5% increase
`at::add_out`:
before: 63,897,395 instructions. Full output: https://www.internalfb.com/intern/everpaste/?handle=GBrrKwtAPlix9wUEAOZtrFXpdO5UbsIXAAAz
after: 73,170,346 instructions. Full output: https://www.internalfb.com/phabricator/paste/view/P415423227
delta: ~14.5% increase
High level takeaway: A framework overhead increase of 10-20% doesn't seem too horrible for the CPU fallback use case.
For structured, functional ops that requires a CPU fallback, we're actually in an unfortunate situation: we're doing even more work than necessary. Our codegen automatically creates a `CompositeExplicitAutograd` kernel which calls into the `out` operator. So the extra work that we end up doing is:
* An extra dispatcher hop: (at::add -> CompositeExplicitAutograd -> CPUFallback -> at::native::add) instead of (at::add -> CPUFallback -> at::native::add)
* An unnecessary tensor allocation (the CompositeExplicitAutograd kernel uses at::empty() to create an output tensor, which is immediately overwritten by the CPU fallback)
* An unnecessary meta() call (the CompositeExplicitAutograd kernel calls it to create the output tensor, but we call it again in the CPU kernel).
* unboxing->boxing->unboxing logic (this is the only strictly required piece)
There are definitely ways to avoid the unnecessary work explained above: one would be to give the boxed fallback higher priority than composite keys (there's [an issue for it here](https://github.com/pytorch/pytorch/issues/55104)), and codegen fallthroughs for all composite ops. It'll require more infra to set up, so I see it as more of a perf knob that we can apply if we need it later.
Unfortunately I couldn't dig much deeper into the differences aside from the aggregate change in instructions, since it looks like callgrind fudged some of the instruction attribution (`at::to_cpu` takes up a ton of instructions, but I don't see any attribution for the `at::native::add` kernel anywhere).
Test Plan: Imported from OSS
Reviewed By: jbschlosser
Differential Revision: D28833085
Pulled By: bdhirsh
fbshipit-source-id: 537ebd5d7fb5858f1158764ff47132d503c3b92b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60214
Relanding this PR, but with a fix for windows cuda builds (example failure in master here: https://github.com/pytorch/pytorch/runs/2852662871)
This is identical to the original PR except for one change in `tools/codegen/gen.py`: `static constexpr` -> `static CONSTEXPR_EXCEPT_WIN_CUDA`
This actually took a while to figure out, until I tracked down a previous pytorch PR that encountered a similar issue: https://github.com/pytorch/pytorch/pull/40675
This reverts commit 6d0fb85a62.
Test Plan: Imported from OSS
Reviewed By: ezyang
Differential Revision: D29213932
Pulled By: bdhirsh
fbshipit-source-id: b90c7c10e5a51f8d6173ddca673b418e5774c248
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59115
This PR beefs up the `at::_ops::` API as a source of truth for compile-time information about each operator.
### Changes
For every op defined in native_functions.yaml, e.g. `at::_ops::add_Tensor` previously defined an unambiguous function; effectively an unambiguously named version of the C++ API that you could decltype() successfully because it had no overloads with a user-facing macro: `decltype(ATEN_FN2(add, Tensor)) // expands to decltype(at::_ops::add_Tensor)`.
Now, `at::_ops::add_Tensor` is a struct containing a few static fields and methods (declared in `Operators.h`, defined in `Operators.cpp`):
```
struct TORCH_API add_Tensor {
using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &);
using ptr_schema = at::Tensor (*)(const at::Tensor &, const at::Tensor &, const at::Scalar &);
static constexpr const char* name = "aten::add";
static constexpr const char* overload_name = "Tensor";
static constexpr const char* schema_str = "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor";
static at::Tensor call(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & ot
};
```
What used to be the function `at::_ops::add_Tensor` can now be accessed as `at::_ops::add_Tensor::call`, and I've added a new macro to access the entire struct (naming suggestions welcome) - `ATEN_OP2(add, Tensor)`.
### Motivation
There were two motivations for this change:
**Codegen refactor**
The `at::_ops::` API as it exists now is (yet another) C++ entry point into the dispatcher, in addition to the Function, Method, and Redispatch APIs. Instead, after this PR, the existing three API's are all inline-able wrapper API's that call into the `at::_ops` API to do the real work. The function and method API's call into `at::_ops::{op}::call`, while the redispatch API calls into `at::_ops::{op}::redispatch`.
This will hopefully make it easier to pile in any future C++ API's that we want to code-generate. It also means that stuff like the string name, overload name, and schema of each operator is consolidated in a single place, rather than having the codegen hardcode various strings in multiple codegen output files.
**Extra compile-time metadata**
In the [boxed CPU fallback PR](https://github.com/pytorch/pytorch/pull/58065/files#diff-c9b55f0d692a9bea8019c6f19bc46877f1efa0f9d4fc2086cf299b52768343b4R31) above this in the stack, I added a new API that external backends can use to call directly into their boxed fallback from an unboxed context. Adding extra metadata to `at::_ops` means that XLA's usage of that API doesn't require passing in the string name and overload of each name as arguments; we can just infer them.
The updated API looks like this (see [the XLA-side PR ](https://github.com/pytorch/xla/pull/2945/files#diff-5e65c3c1d847191cb691d1874732e971f09fa1aad7a980a555c3b0504a5b6470R250) for more examples)
```
return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP2(add, Tensor)>::call(a, b, 1.0);
```
**Characteristics of the `at::_ops` API**
(I also commented this in the codegen)
(1) It follows the Dispatcher API.
This means, e.g., that it takes in the expanded arguments rather than `TensorOptions`. This is kind of necessary for perf, if we want to `at::_ops` to serve as the main implementation of the existing C++ API's. For example: if it followed the C++ API, then all of the faithful C++ factory functions would need to wrap their arguments into TensorOptions only to unwrap them again.
(2) Overload names are disambiguated.
This is the same as before; it's helpful for pytorch extenders who would like to decltype() an aten operator, that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
(3) No argument defaulting is allowed.
This is more of an implementation detail to avoid #include cycles, since TensorBody.h (which defines the Tensor class) needs to include this file. The #include situation is precarious though!
(4) manual_cpp_bindings and faithful names are not included in the API.
I think that this is one we have a choice with. This applies to stuff like __dispatch__is_complex(), and add_outf(). These aren't "real native_functions.yaml ops", they're just additional functions provided by the C++ API. They're implemented as wrappers in Functions.h that call into the actual operators defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call(). This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher. It also means that `ATEN_OP2(add, out)` is automatically faithful and takes its out argument at the end (this is just because it follows the dispatcher API).
**Details**
Instead of codegen'ing the existing 3 API's in `Functions.cpp`, `TensorMethods.cpp` and `RedispatchFunctions.cpp`, I codegen them directly into the headers: `Functions.h`, `TensorBody.h`, and `RedispatchFunctions.h`. I mostly did this for perf, since we want to avoid introducing an extra function call in the hot path of every operator. These functions are also now all one-liners that call into `at::_ops`, so the compiler should just inline them all anyway.
The main downside in doing that though was that I had to bend over backwards in a few cases to avoid cyclical #include statements. The issue is that `TensorBody.h` now includes `Operators.h` (because the codegen'd method API is implemented by calling into `at::_ops`), but `TensorBody.h` also includes the definition of the Tensor class. That means that `Operators.h` can't be aware of the Tensor class; it needs to forward declare everything and avoid using the Tensor class directly. To fix cyclic includes, I had to:
- Not allow defaulting in the `at::_ops` API
- Move some code that was called when translating from C++ to Dispatcher API's directly into the codegen template (`check_tensor_options_and_extract_memory_format`)
It's not great, but I don't think this specific include cycle will break down in the near future; the only code that we need to call before getting to `Operators.cpp` is the translations from various API's to the dispatcher API; there aren't many of them, and there's no major reason for them to live an external utils file somewhere.
Moving the code into the headers also meant that the codegen no longer needs to deal with `Functions.cpp`/`TensorMethods.cpp`/`RedispatchFunctions.cpp`. All of the functions that used to be defined in `TensorMethods.cpp` seemed small enough for me to lump into `TensorBody.h`, but some of the functions in `Functions.cpp` looked pretty big to put in a header, so I moved the file to `aten/src/ATen/native/Functions.cpp`.
It might be worth keeping `TensorMethods.cpp` there and leaving it too, in-case we have any beefy hand-written tensor methods that we don't want to put in a header.
**Perf**
I ran a few benchmarks in callgrind, and didn't see a noticeable instruction count change when calling `at::add()`. I also saw in the output that `at::add()` was successfully getting inlined.
There's also probably a light risk of binary size increase; I think that there's a binary size regression test that I can run in phabricator (going to try it). I can also try inspecting `libtorch.so` directly and seeing if it's any bigger, but my hope is that the inline-ing means that we aren't generated separate symbols for `at::add` and `at::_ops::add_Tensor::call`.
Test Plan: Imported from OSS
Reviewed By: ezyang
Differential Revision: D28833086
Pulled By: bdhirsh
fbshipit-source-id: 55f322a8378cb9a3cb6642f72aa291be381dd95b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58570
**What the PR does**
Generate a fast-path `at::meta::{op}` API for calling meta functions without having to go through the dispatcher. This will be important for perf for external backends that want to use meta functions for shape checking (which seems likely to be what we end up doing for LazyTensorCore).
**Details**
In order to avoid naming collisions I had to make two small changes:
- rename `MetaFunctions.h` template -> `NativeMetaFunctions.h` (this is the file that declares the impl() function for every structured operator).
- rename the meta class: `at::meta::{op}::meta()` -> `at::meta::structured_{op}::meta()`
I also deleted a few unnecessary includes, since any file that includes NativeFunctions.h will automatically include NativeMetaFunctions.h.
**Why I made the change**
This change isn't actually immediately used anywhere; I already started writing it because I thought it would be useful for structured composite ops, but that isn't actually true (see [comment](https://github.com/pytorch/pytorch/pull/58266#issuecomment-843213147)). The change feels useful and unambiguous though so I think it's safe to add. I added explicit tests for C++ meta function calls just to ensure that I wrote it correctly - which is actually how I hit the internal linkage issue in the PR below this in the stack.
Test Plan: Imported from OSS
Reviewed By: pbelevich
Differential Revision: D28711299
Pulled By: bdhirsh
fbshipit-source-id: d410d17358c2b406f0191398093f17308b3c6b9e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58569
This should allow external C++ files that aren't compiled into `libtorch.so`/`libtorch_cpu.so` (including all of fbcode) to use fast path functions like `at::cpu::add()`, which skip the dispatcher.
So, after spending way too much time trying to figure out why I was getting linker errors when calling `at::meta::{op}` and `at::cpu::{op}` from C++ test files, I realized that we're not including the header files for C++ for the namespaced operator definitions. I.e. `RegisterCPU.cpp`, which provides definitions for the `at::cpu::{op}` fast path functions, wasn't including the `CPUFunctions.h` header.
Why that breaks stuff: the `CPUFunctions.h` header file is what marks each function with the `TORCH_API` macro, so without including it, when we build `libtorch.so` and `libtorch_cpu.so`, the compiler will look at the definition in `RegisterCPU.cpp`, not see a `TORCH_API`, and decide that the function should get internal linkage.
An alternative would be to directly mark the function definitions in `RegisterCPU.cpp` with `TORCH_API`, but this seemed cleaner.
Test Plan: Imported from OSS
Reviewed By: pbelevich
Differential Revision: D28711300
Pulled By: bdhirsh
fbshipit-source-id: 535f245c20e977ff566d6da0757b3cefa137040b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59850
This whole stack does not change anything to the codegened code
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D29063816
Pulled By: albanD
fbshipit-source-id: ca3067443d8e6282c1077d3dafa3b4f330d43b28
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59849
This whole stack does not change anything to the codegened code
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D29063815
Pulled By: albanD
fbshipit-source-id: c4baa72594bd2fe50ac67f513916f2b2ccb7488c
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59848
This whole stack does not change anything to the codegened code
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D29063818
Pulled By: albanD
fbshipit-source-id: c68734672eeacd212d7bd9bebe3d53aaa20c3c24
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59847
This whole stack does not change anything to the codegened code
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D29063817
Pulled By: albanD
fbshipit-source-id: 284c3e057029b7a67f43a1b034bb30863bd68c71
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59096
RegisterBackendSelect is bringing in ~100 extra ops to the runtime. This messes with the compatibility api, and also adds a nontrivial amount of size.
Test Plan: Model Unittests/CI
Reviewed By: iseeyuan
Differential Revision: D28588100
fbshipit-source-id: ffd0b5b9cbe20f27dbf3be418a6c1f80c7396fdb
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59018Fixes#58044.
This PR:
- adds `ATEN_FN(op)` and `ATEN_FN2(op, overload)` macros that resolve to
an non-overloaded function in aten::_ops that calls the desired operator
(without default arguments).
The motivation for this is two-fold:
1) Using aten operators with templates is hard if the operator is
overloaded (e.g. add.Tensor and add.Scalar).
2) Method-only operators require special handling; pointers-to-method
are different from function pointers. `ATEN_FN2(add_, Tensor)` returns
a function instead of a method.
There is some interesting behavior for out= operations.
`ATEN_FN2(sin, "out")` gives a function that is *faithful* to the schema;
that is, the order of arguments is exactly what it looks like in the
schema. This makes it so that you can directly register
`ATEN_FN2(sin,"out")` (or a function wrapping it using the same signature)
as an override for a DispatchKey.
Test Plan:
- New tests that ATEN_FN2 works on function and method-only operators
- New test that ATEN_FN works
- New test that ATEN_FN macro returns a "faithful" function.
Codegen output:
Operators.h and Operators.cpp are both here:
https://gist.github.com/zou3519/c2c6a900410b571f0d7d127019ca5175
Reviewed By: bdhirsh
Differential Revision: D28721206
Pulled By: zou3519
fbshipit-source-id: a070017f98e8f4038cb0c64be315eef45d264217
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58092Fixes#58044.
This PR:
- adds `ATEN_FN(op)` and `ATEN_FN2(op, overload)` macros that resolve to
an non-overloaded function in aten::_ops that calls the desired operator
(without default arguments).
The motivation for this is two-fold:
1) Using aten operators with templates is hard if the operator is
overloaded (e.g. add.Tensor and add.Scalar).
2) Method-only operators require special handling; pointers-to-method
are different from function pointers. `ATEN_FN2(add_, Tensor)` returns
a function instead of a method.
There is some interesting behavior for out= operations.
`ATEN_FN2(sin, "out")` gives a function that is *faithful* to the schema;
that is, the order of arguments is exactly what it looks like in the
schema. This makes it so that you can directly register
`ATEN_FN2(sin,"out")` (or a function wrapping it using the same signature)
as an override for a DispatchKey.
Test Plan:
- New tests that ATEN_FN2 works on function and method-only operators
- New test that ATEN_FN works
- New test that ATEN_FN macro returns a "faithful" function.
Codegen output:
Operators.h and Operators.cpp are both here:
https://gist.github.com/zou3519/c2c6a900410b571f0d7d127019ca5175
Reviewed By: mruberry
Differential Revision: D28643215
Pulled By: zou3519
fbshipit-source-id: 7b2b8459f1b2eb5ad01ee7b0d2bb77639f77940e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58889
fixes https://github.com/pytorch/pytorch/issues/58796
Planning on re-testing locally tomorrow morning to confirm, but this change should fix the non-determinism in the codegen output that was causing `ccache` not to re-use its cached output.
I built from the commit referenced in https://github.com/pytorch/pytorch/issues/58796 a few times and ran `diff -Naur` on the codegen output in `build/aten/src/ATen`. After a few tries, `NativeFunctions.h` had a few diffs. The diffs were all related to the ordering of functional/inplace/out variants of a NativeFunctionGroup, which looked non-deterministic.
That looks like it's coming from my calling `set()` to filter out duplicate NativeFunction declarations. The earlier version of the codegen also called `set()` to filter out duplicates, but it did so individually for each `NativeFunction` object, before merging the groups (I'm not too sure why this didn't introduce non-determinism before. though). With the refactor from https://github.com/pytorch/pytorch/pull/57361, we're calling `set()` on the declarations from every operator for a given DispatchKey, which is probably what introduced the nondeterminism.
Test Plan: Imported from OSS
Reviewed By: gchanan
Differential Revision: D28675941
Pulled By: bdhirsh
fbshipit-source-id: bb66de00aafeeb9720d85e8156ac9f7539aed0d6
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57510
This is a re-write of https://github.com/pytorch/pytorch/pull/56835, which is significantly shorter thanks to the data model change in the PR below this one in the stack. See the original description in the linked PR for details.
The functional changes in this PR are the same as in the above linked one, so the description is the same with a few small changes:
- I don't bother generating `at::xla::{op}` entries for CPU fallbacks. After looking around, I see precedent for that. For example, we don't have `at::cpu::{op}` entries for composite ops- if you really want to bypass the dispatcher you need to call `at::compositeimplicitautograd::{op}`. Maybe we should revisit that later if we find an important use case for having full namespace coverage, but that doesn't seem worth half-fixing for external backends in this PR.
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D28474364
Pulled By: bdhirsh
fbshipit-source-id: 4d58b60e5debad6f1ff06420597d8df8505b2876
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57361
Data model change in the codegen, which splits backend-specific information out of `NativeFunction`
### Overview
Currently in the codegen, native_functions.yaml has backend-specific information about each operator that is encoded directly into the data model, in the `NativeFunction` object. That's reasonable, since the native_functions.yaml is the source of truth for information about an operator, and the data model encodes that information into types.
Now that external backends can use the codegen though, that information is technically incomplete/inaccurate. In another PR, I tried patching the information on the `NativeFunction` object with the additional external information, by updating the `dispatch` entry to contain the external backend kernel name and dispatch key.
Instead, this PR tries to split out that information. The `NativeFunction` class contains all information about an operator from native_functions.yaml that's backend-independent and is known never to change regardless of what extra information backends provide. We also build up a backend "index", which is basically a mapping from [backend] -> [backend-specific-metadata]. Reading in an external backend yaml just involves updating that index with the new backend.
There were a few places where `NativeFunction` used the dispatch table directly, that I encoded as properties directly on the NativeFunction object (e.g. `is_abstract`). They were mostly around whether or not the operator has a composite kernel, which isn't something that's going to change for any external backends.
This has a few advantages:
- We can more easily re-use the existing logic in `native_function.py` and `register_dispatch_key.py` for both native and external backends, since they both involve a NativeFunction + a particular backend index
- The data in the data model will be the same regardless of how the codegen is run. Running the codegen with a new external backend doesn't change the data inside of NativeFunction or an existing backend index. It just adds a new index for that backend.
- There are several of codegen areas that don't care about backend-specific information: mostly the tracing and autograd codegen. We can reason about the codegen there more easily, knowing that backend-specific info is entirely uninvolved.
An alternative to this split would be to augment the NativeFunction objects with external backend information at the time that we create them. So the external codegen could read both native_functions.yaml and the external backend's yaml at the same time, and construct a NativeObject with a full dispatch table (including the XLA entry), and the correct setting of structured (taking into account both yamls). One disadvantage to this approach is that NativeFunction objects now contain different stuff depending on how you ran the codegen, and you have to make sure that any changes to the codegen can properly handle all the different variants.
### Data Model Changes
Removed 3 classes, which are used by the external codegen:
- ExternalBackendFunction
- ExternalBackendFunctionsGroup
- ExternalBackendMetadata
And added two new ones:
- BackendIndex
- BackendMetadata
`BackendIndex` contains any info that's specific to that backend, plus a mapping from operator names to backend specific metadata about the operator. One example of backend-specific info that's not operator-dependent is the fact that XLA prefers to implement functional kernels instead of out kernels (and so when they eventually mark an op as structured, they're going to mark the functional op and not the out op).
`BackendMetadata` contains info specific to an (operator, backend) pair. Right now, that's just (a) the name of the kernel, and (b) whether or not that operator is structured.
### Questions
I wanted to get this PR up earlier so I could get feedback, but there are a few things I want to call out:
**Dealing with `structured`.**
This PR separates out the notion of `structured` into two bits of information:
- Does [operator] have a meta() function. This is backend-agnostic, and is represented by the `structured` property on `NativeFunction`, same as before. This is used, e.g., to decide what signatures to add to `MetaFunctions.h`.
- Does [operator, backend] have an impl() function. This is backend dependent; even though technically all in-tree backends are forced to write impl() functions for an operator when we port the op to structured in native_functions.yaml, out-of-tree backends can decide to opt in independently. This is represented as a property on `BackendMetadata`. This is used in most other cases, e.g. in `RegisterDispatchKey` when we're deciding whether or not to gen a structured or unstructured wrapper.
I also baked `is_structured_dispatch_key` directly into each BackendIndex. So for operators marked "structured" in native_functions.yaml, their corresponding CPU/CUDA BackendIndex entries will be marked structured, and all others (except for potentially external backends) will not.
I ended up trying to deal with `structured` in this change since it's technically backend dependent (XLA can opt kernels into structured separately from in-tree ops), but that may have been too ambitious: it's technically not relevant until we actually add support for structured external kernels. If it's not clear that this is the right path for dealing with structured and we want to push that off, I'm fine with backing out the bits of this PR that make `structured` backend-dependent. I don't see anything *too* controversial related to structured in the change, but I tried to call out any areas in the comments
**Localizing the fact that external backends follow Dispatcher convention.**
Another thing that's sort of backend specific that I didn't totally address in this PR is the fact the fact that in-tree backends follow the Native API while external backends follow the Dispatcher API. I painted over that in `native_functions.py` by adding a helper, `kernel_signature`, that takes in a native function and gives you the "correct" signature for the specified backend- NativeSignature for in-tree backends, and DispatcherSignature for out-of-tree backends. In order to make that fully useable though, we'll need `NativeSignature` and `DispatcherSignature` to have matching interfaces. I didn't bother with that in this PR, which is why `gen_external_aten_fallbacks.py` still has a bunch of direct references to the dispatcher API. Thinking of adding it in a later PR but wanted to see if anyone has other opinions.
Maybe `is_external()` shouldn't even be a property on the BackendMetadata, and anything the codegen does that requires asking for that information should just be better abstracted away.
**Thoughts on the `BackendIndex` / `BackendMetadata` breakdown.**
One thing that's annoying right now is that to query for various pieces of metadata, you call helper functions like `backend_index.structured(f)`, which queries that particular backend and tells you if that specific NativeFunctionGroup is structured for that backend. It has to return an `Optional[bool]` though, since you have to handle the case where that operator doesn't have a kernel for that backend at all. So users of those helpers end up with a bunch of optionals that they need to unpack, even if they know at some point that the result isn't None. I think it would be easier instead to just store the NativeFunction object as a field directly on the BackendMetadata. Curious if there are any other opinions on a better way to model it though.
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D28474362
Pulled By: bdhirsh
fbshipit-source-id: 41a00821acf172467d764cb41e771e096542f661
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56601
Updating it to ensure that RegistrationDeclarations.yaml is completely
unchanged
This reverts commit 90e532f3ef.
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D27915305
Pulled By: bdhirsh
fbshipit-source-id: 491a025c44221690dad849f9a2166934130c0fec
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55351
We incorrectly used `Tensor&` to mean "the underlying
TensorImpl cannot be changed", as explained in
https://github.com/zdevito/ATen/issues/27#issuecomment-330717839 .
This diff gets us on the path to fixing this problem: we have an
incremental way to fix individual native functions so that we can
apply any handwritten fixes a few at a time. It gets the migration
started with the `resize` family of native functions.
ghstack-source-id: 127092677
Test Plan: fitsships
Reviewed By: ezyang
Differential Revision: D27583983
fbshipit-source-id: 4eeeec85f5d268e9d0f1645eb9396914a9f9557f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56307
This should fix https://github.com/pytorch/pytorch/issues/56273. I tested these changes locally by making them directly on top of https://github.com/pytorch/pytorch/pull/56151, and running the xla tests (`xla/test/cpp/build/test_ptxla`).
**Current state:** For ops that are ported to structured, If external backends like XLA have implemented the `out` op but not the `functional` version, they will call into our code-generated `CompositeExplicitAutograd` kernel, which calls the structured operator's `meta()` function and then redispatches to the external backend's `out` function.
If XLA has registered their own kernel to the `functional` variant of the op, it'll override our codegen'd composite kernel. XLA has logic to code-generate "CPU fallback" kernels for "required" ops. It gets this information based off of `RegistrationDeclarations.yaml`. That info was technically incorrect up until this PR, since we were code-generating `inplace/functional` composite kernels for structured ops, but not updating `RegistrationDeclarations.yaml` with that information.
Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D27883950
Pulled By: bdhirsh
fbshipit-source-id: fe896b0d2bbd4369490dcdf7a87f227fd3d8b8b3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55047
Added namespaces to all of the `CTypes` printed in the codegen. This is pretty much required if we want to use codegen externally, since we can no longer assume that we're inside of the `at::` namespace.
Important changes are in `types.py`.
How do we add the notion of namespaces to C++ types without people having to write "at::Tensor" everywhere? Before this PR, `CType` held a raw string representing the type, i.e. `BaseCType("Tensor", binds)`. This PR introduces a set of singleton base C++ types in `types.py`, that know how to print their namespace. Instead, we'd write `BaseCType(tensorT, binds)`, where printing `tensorT` will properly print out "at::Tensor".
This also means that you can't create arbitrary `CTypes`. If we need a new C++ type in the codegen, we need to add it to the list in `types.py`.
One blip in the design: we don't want to change `RegistrationDeclarations.yaml`, since that'll break external backends that ingest it. I added separate functions to display types without the namespace that are used to create RegistrationDeclarations.yaml`. With an external codegen API though, we can eventually kill it :)
I also didn't realize until this PR that `Declarations.yaml` is still directly in use, by some python/autograd codegen. Rather than keep that yaml byte-for-byte compatible, I just updated the callsites in the autograd codegen to work with namespaces. In the NEXT pr, I try to clean up some of the autograd codegen to stop using raw strings to match against C++ types.
Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D27708349
Pulled By: bdhirsh
fbshipit-source-id: 56a4f81fc101795bcb9ee1f722121480fb2356ad
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55046
Updating `returns` in the codegen to return a CType instead of a raw string.
This has benefit of putting all stringifying logic through CType, which is useful in the followup PR when I add namespaces.
I also added new CTypes for other templated C++ types: array, vector and tuple. Mostly because it makes the namespacing logic in the next PR significantly easier. It also seems more natural to me that `BaseCType` shouldn't represent specializations of templated types.
There's a little bit of weirdness, types that are currently *only* used for returns, i.e. `TupleCType`. Returns aren't named, so I opted not to give it one- so we can add it in later if we discover that we need it.
Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D27708348
Pulled By: bdhirsh
fbshipit-source-id: 230b210c3e53be1bd362105fbea8451055dc59a8
Summary:
Generally wildcard imports are bad for the reasons described here: https://www.flake8rules.com/rules/F403.html
This PR replaces wildcard imports with an explicit list of imported items where possible, and adds a `# noqa: F403` comment in the other cases (mostly re-exports in `__init__.py` files).
This is a prerequisite for https://github.com/pytorch/pytorch/issues/55816, because currently [`tools/codegen/dest/register_dispatch_key.py` simply fails if you sort its imports](https://github.com/pytorch/pytorch/actions/runs/742505908).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55838
Test Plan: CI. You can also run `flake8` locally.
Reviewed By: jbschlosser
Differential Revision: D27724232
Pulled By: samestep
fbshipit-source-id: 269fb09cb4168f8a51fd65bfaacc6cda7fb87c34
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54470
```
git grep -l 'DefaultBackend' | xargs sed -i 's/DefaultBackend/CompositeExplicitAutograd/g'
```
Plus a quick fixup in native/README.md
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D27253240
Pulled By: ezyang
fbshipit-source-id: 964df951ea8b52fa72937f3cc66aeaf49a702e6f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54466
I had to very carefully audit all the use sites since there are a lot
of other uses of the string Math; I did most of the conversion by
grepping for all occurrences of Math and then doing a search
replace.
I also updated documentation for clarity.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D27253239
Pulled By: ezyang
fbshipit-source-id: afb485d07ff39575742a4f0e1e205179b60bc953
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54427
A StructuredNativeFunctions is no longer guaranteed to actually
be structured (test structured property for that), so we rename
this to a more neutral name.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D27235380
Pulled By: ezyang
fbshipit-source-id: 2b438d615bf06a47fc9c7bf6eb66fd8b4df31bc8
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54426
Previously, we only put NativeFunctions in StructuredNativeFunctions
if the out variant advertised that the kernel was structured. However,
there are a few code generation things that can take advantage of
this trio structure, even if the kernel itself hasn't been ported
to be structured. So better to always group things when they are
related, and then let clients decide whether or not to use the
structure or throw it away.
While doing this, I had hoped that there weren't any functional/inplace
pairs that didn't also have an out variant. This turned out to not
be true. These are probably all oversights and should get fixed at
some point.
Bill of changes:
- The actual operational change happens in
StructuredNativeFunctions.from_dict; then I need to relax some
__post_init__ invariants. To tell if a StructuredNativeFunctions
is actually structured, there is a new structured property, which
is queried from a few new locations in code
- Refactor native_functions.py into gen_structured/gen_unstructured
functions so I can easily call gen_unstructured from two contexts
I intend to s/StructuredNativeFunctions/NativeFunctionsGroup/ but
for ease of review this rename hasn't been done in this PR.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D27235379
Pulled By: ezyang
fbshipit-source-id: d8a15de9abb75b365348ab94e67b830704e30cf0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54419
I'm planning to break it into some helper functions, so let's put it in its own module first.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D27235378
Pulled By: ezyang
fbshipit-source-id: c03c5440d2d753859e2c5ec2b2c8b1b82870f03a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53859
The redispatch API wasn't linking properly when static dispatch is enabled. I'm still not sure why this wasn't caught by the static dispatch test in CI- maybe, as swolchok pointed out, we have a flag set somewhere that defers undefined symbols until runtime.
Before, building with static dispatch enabled locally + running `import torch` gave me this error:
```
>>> import torch
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/raid/hirsheybar/pytorch/torch/__init__.py", line 197, in <module>
from torch._C import *
ImportError: /raid/hirsheybar/pytorch/torch/lib/libtorch_cpu.so: undefined symbol: _ZN2at10redispatch11logical_or_EN3c1014DispatchKeySetERNS_6TensorERKS3_
>>>
```
Printing the symbol:
```
(pytorch) hirsheybar@devfair017:/scratch/hirsheybar/pytorch$ c++filt _ZN2at10redispatch11logical_or_EN3c1014DispatchKeySetERNS_6TensorERKS3_
at::redispatch::logical_or_(c10::DispatchKeySet, at::Tensor&, at::Tensor const&)
```
Sure enough, the functions defined in `RedispatchFunctions.cpp` don't have the DispatchKeySet argument included. Adding them in this PR.
Test Plan: Imported from OSS
Reviewed By: ljk53
Differential Revision: D26998735
Pulled By: bdhirsh
fbshipit-source-id: c6c1104e42d13b7ec9d964b7e08d2adc8b344b78