Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
13 Commits
ce7910ba81
ufunc codegen (#65851)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65851

Design doc: https://docs.google.com/document/d/12rtlHnPUpaJ-I52Iob3L0WA3rKRr_OY7fXqeCvn2MVY/edit

First read the design doc to understand the user syntax. In this PR, we have converted add to use ufunc codegen; most of the cpp changes are deleting the preexisting implementations of add, and ufunc/add.h contains the new implementations in the ufunc format.

The bulk of this PR is in the new codegen machinery. Here's the order to read the files:

* `tools/codegen/model.py`
  * Some self-explanatory utility classes: `ScalarType`, `DTYPE_CLASSES`
  * New classes for representing ufunc entries in `native_functions.yaml`: `UfuncKey` and `UfuncInnerLoop`, as well as parsing logic for these entries. `UfuncKey` has some unusual entries (e.g., CPUScalar) that don't show up in the documentation; more on these below.
  * A predicate `is_ufunc_dispatch_key` for testing which dispatch keys should get automatically generated when an operator opts into ufuncs (CPU and CUDA, for now!)
* `tools/codegen/api/types.py`
  * More self-explanatory utility stuff: `ScalarTypeToCppMapping`, which maps `ScalarType` to `CppType`; `Binding.rename`, for changing the name of a binding (used when we assign constructor variables to member variables inside CUDA functors)
  * New `VectorizedCType`, representing `at::vec::Vectorized<T>`. This is used inside vectorized CPU codegen.
  * New `scalar_t` and `opmath_t` BaseCppTypes, representing template parameters that we work with when doing codegen inside ufunc kernel loops (e.g., where you previously had `Tensor`, now you have `scalar_t`)
  * `StructuredImplSignature` represents a `TORCH_IMPL_FUNC` definition, and straightforwardly follows from the preexisting `tools.codegen.api.structured`
* `tools/codegen/translate.py` - Yes, we use translate a LOT in this PR. I improved some of the documentation; the only substantive changes are adding two new conversions: given a `scalar_t` or a `const Scalar&`, make it convertible to an `opmath_t`.
* `tools/codegen/api/ufunc.py`
  * OK, now we're at the meaty stuff. This file represents the calling conventions of three important concepts in ufunc codegen, which we'll describe shortly. All of these APIs are relatively simple, since there aren't any complicated types by the time you get to kernels.
  * Stubs are the DispatchStub trampolines that CPU kernels use to get to their vectorized versions. They drop all Tensor arguments (as they are in TensorIterator) but otherwise match the structured calling convention.
  * Ufuncs are the inner-loop template functions that you wrote in ufunc/add.h, which do the actual computation in question. Here, all the Tensors and Scalars have been converted into the computation type (`opmath_t` in CUDA, `scalar_t` in CPU).
  * Ufunctors are a CUDA-only concept representing functors that take some of their arguments in a host-side constructor, and the rest in the device-side apply. Once again, Tensors and Scalars are converted into the computation type, `opmath_t`, but for clarity all the functions take `scalar_t` as argument (as this is the type that is most salient at the call site). Because the constructor and apply are code generated separately, `ufunctor_arguments` returns a teeny struct `UfunctorBindings`.
* `tools/codegen/dest/ufunc.py` - the workhorse. This gets its own section below.
* `tools/codegen/gen.py` - just calling out to the new dest.ufunc implementation to generate UfuncCPU_add.cpp, UFuncCPUKernel_add.cpp and UfuncCUDA_add.cu files per ufunc operator.
Each of these files does what you expect (small file that registers the kernel and calls the stub; CPU implementation; CUDA implementation). There is a new file manager for UFuncCPUKernel files, as these need to get replicated by cmake for vectorization. One little trick to avoid recompilation is that we directly replicate code generated forward declarations in these files, to reduce the number of headers we depend on (this is codegen, we're just doing the preprocessor's job!)

* I'll talk about build system adjustments below.

OK, let's talk about tools/codegen/dest/ufunc.py. This file can be roughly understood in two halves: one for CPU code generation, and the other for CUDA code generation.

**CPU codegen.** Here's roughly what we want to generate:

```
// in UfuncCPU_add.cpp
using add_fn = void (*)(TensorIteratorBase&, const at::Scalar&);
DECLARE_DISPATCH(add_fn, add_stub);
DEFINE_DISPATCH(add_stub);

TORCH_IMPL_FUNC(ufunc_add_CPU)
(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha, const at::Tensor& out) {
  add_stub(device_type(), *this, alpha);
}

// in UfuncCPUKernel_add.cpp
void add_kernel(TensorIteratorBase& iter, const at::Scalar& alpha) {
  at::ScalarType st = iter.common_dtype();
  RECORD_KERNEL_FUNCTION_DTYPE("add_stub", st);
  switch (st) {
    AT_PRIVATE_CASE_TYPE("add_stub", at::ScalarType::Bool, bool, [&]() {
      auto _s_alpha = alpha.to<scalar_t>();
      cpu_kernel(iter, [=](scalar_t self, scalar_t other) {
        return ufunc::add(self, other, _s_alpha);
      });
    })
    AT_PRIVATE_CASE_TYPE(
        "add_stub", at::ScalarType::ComplexFloat, c10::complex<float>, [&]() {
      auto _s_alpha = alpha.to<scalar_t>();
      auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
      cpu_kernel_vec(
          iter,
          [=](scalar_t self, scalar_t other) {
            return ufunc::add(self, other, _s_alpha);
          },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) {
            return ufunc::add(self, other, _v_alpha);
          });
    })
    ...
```

The most interesting change about the generated code is that what previously was an `AT_DISPATCH` macro invocation is now an unrolled loop. This makes it easier to vary behavior per-dtype (you can see in this example that the entries for bool and complex float differ) without having to add extra conditionals on top.

Otherwise, to generate this code, we have to hop through several successive API changes:

* In TORCH_IMPL_FUNC(ufunc_add_CPU), go from StructuredImplSignature to StubSignature (call the stub). This is normal argument massaging in the classic translate style.
* In add_kernel, go from StubSignature to UfuncSignature. This is nontrivial, because we must do various conversions outside of the inner kernel loop. These conversions are done by hand, setting up the context appropriately, and then the final ufunc call is done using translate. (BTW, I introduce a new convention here, `call` on a Signature, for code generating a C++ call, and I think we should try to use this convention elsewhere.)

The other piece of nontrivial logic is the reindexing by dtype. This reindexing exists because the native_functions.yaml format is indexed by UfuncKey:

```
Generic: add (AllAndComplex, BFloat16, Half)
ScalarOnly: add (Bool)
```

but when we do code generation, we case on dtype first, and then we generate a `cpu_kernel` or `cpu_kernel_vec` call. We also don't care about CUDA code generation (which Generic hits). To do this, we lower these keys into two low level keys, CPUScalar and CPUVector, which represent the CPU scalar and CPU vectorized ufuncs, respectively (Generic maps to CPUScalar and CPUVector, while ScalarOnly maps to CPUScalar only).
Reindexing then gives us:

```
AllAndComplex:
  CPUScalar: add
  CPUVector: add
Bool:
  CPUScalar: add
...
```

which is a good format for code generation, but too wordy to force native_functions.yaml authors to write. Note that when reindexing, it is possible for there to be a conflicting definition for the same dtype; we just define a precedence order and have one override the other, so that it is easy to specialize on a particular dtype if necessary. Also note that because CPUScalar/CPUVector are part of UfuncKey, technically you can manually specify them in native_functions.yaml, although I don't expect this functionality to be used.

**CUDA codegen.** CUDA code generation has many of the same ideas as CPU codegen, but it needs to know about functors, and stubs are handled slightly differently. Here is what we want to generate:

```
template <typename scalar_t>
struct CUDAFunctorOnSelf_add {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t other_;
  opmath_t alpha_;
  CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
  __device__ scalar_t operator()(scalar_t self) {
    return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
  }
};

... two more functors ...

void add_kernel(TensorIteratorBase& iter, const at::Scalar & alpha) {
  TensorIteratorBase& iter = *this;
  at::ScalarType st = iter.common_dtype();
  RECORD_KERNEL_FUNCTION_DTYPE("ufunc_add_CUDA", st);
  switch (st) {
    AT_PRIVATE_CASE_TYPE("ufunc_add_CUDA", at::ScalarType::Bool, bool, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (false) {
      } else if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(
            iter.scalar_value<opmath_t>(1), (alpha).to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(
            iter.scalar_value<opmath_t>(2), (alpha).to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>((alpha).to<opmath_t>()));
      }
    })
    ...

REGISTER_DISPATCH(add_stub, &add_kernel);

TORCH_IMPL_FUNC(ufunc_add_CUDA)
(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha, const at::Tensor& out) {
  add_kernel(*this, alpha);
}
```

The functor business is the bulk of the complexity. Like CPU, we decompose the CUDA implementation into three low-level keys: CUDAFunctor (normal; all CUDA kernels will have this), and CUDAFunctorOnSelf/CUDAFunctorOnOther (these are to support Tensor-Scalar specializations when the Scalar lives on CPU). Both Generic and ScalarOnly provide ufuncs for CUDAFunctor, but for us to also lift these into Tensor-Scalar specializations, the operator itself must be eligible for Tensor-Scalar specialization. At the moment, this is hardcoded to be all binary operators, but in the future we can use tags in native_functions.yaml to disambiguate (or perhaps expand codegen to handle n-ary operators).

The reindexing process not only reassociates ufuncs by dtype, but it also works out if Tensor-Scalar specializations are needed and codegens the ufunctors necessary for the level of specialization here (`compute_ufunc_cuda_functors`). Generating the actual kernel (`compute_ufunc_cuda_dtype_body`) just consists of, for each specialization, constructing the functor and then passing it off to `gpu_kernel`.
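Before moving on to functor generation details, here is a simplified Python sketch of the reindex-by-dtype step described above, which both the CPU and CUDA paths rely on. This is only an illustration: the key set, data shapes, and function name are assumptions, not the actual `compute_ufunc_*` helpers in `tools/codegen/dest/ufunc.py`.

```
from enum import Enum

class UfuncKey(Enum):
    Generic = "Generic"
    ScalarOnly = "ScalarOnly"
    CPUScalar = "CPUScalar"
    CPUVector = "CPUVector"
    CUDAFunctor = "CUDAFunctor"

# High-level keys expand into low-level keys; an explicitly written low-level
# key takes precedence over whatever a high-level key would expand to.
LOWERING = {
    UfuncKey.Generic: [UfuncKey.CPUScalar, UfuncKey.CPUVector, UfuncKey.CUDAFunctor],
    UfuncKey.ScalarOnly: [UfuncKey.CPUScalar, UfuncKey.CUDAFunctor],
}

def reindex_by_dtype(entries):
    """entries: {UfuncKey: (ufunc_name, dtypes)} as parsed from native_functions.yaml.
    Returns {dtype: {low_level_key: ufunc_name}}."""
    result = {}
    # Process high-level keys first so that explicit low-level keys override them.
    ordered = sorted(entries.items(), key=lambda kv: kv[0] not in LOWERING)
    for key, (name, dtypes) in ordered:
        for dtype in dtypes:
            per_dtype = result.setdefault(dtype, {})
            for low_key in LOWERING.get(key, [key]):
                per_dtype[low_key] = name
    return result

entries = {
    UfuncKey.Generic: ("add", ["Float", "ComplexFloat", "BFloat16", "Half"]),
    UfuncKey.ScalarOnly: ("add", ["Bool"]),
}
reindexed = reindex_by_dtype(entries)
print(reindexed["Float"])  # {CPUScalar: 'add', CPUVector: 'add', CUDAFunctor: 'add'}
print(reindexed["Bool"])   # {CPUScalar: 'add', CUDAFunctor: 'add'}
```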
Most of the hard work is in functor generation, where we take care to make sure `operator()` has the correct input and output types (which `gpu_kernel` uses to arrange for memory accesses to the actual CUDA tensor; if you get these types wrong, your kernel will still work, it will just run very slowly!)

There is one big subtlety with CUDA codegen: this won't work:

```
Generic: add (AllAndComplex, BFloat16, Half)
ScalarOnly: add_bool (Bool)
```

This is because, even though there are separate Generic/ScalarOnly entries, we only generate a single functor to cover ALL dtypes in this case, and the functor has the ufunc name hardcoded into it. You'll get an error if you try to do this; to fix it, just make sure the ufunc is named the same consistently throughout. In the code, you see this because after testing for the short circuit case (when a user provided the functor themselves), we squash all the generic entries together and assert their ufunc names are the same. Hypothetically, if we generated a separate functor per dtype, we could support differently named ufuncs, but... why would you do that to yourself. (One piece of nastiness is that the native_functions.yaml syntax doesn't stop you from shooting yourself in the foot.)

A brief word about CUDA stubs: technically, they are not necessary, as there is no CPU/CPUKernel style split for CUDA kernels (so, if you look, the structured impl actually calls add_kernel directly). However, there is some code that still makes use of CUDA stubs (in particular, I use the stub to conveniently reimplement sub in terms of add), so we still register it. This might be worth revisiting at a later point in time.

**Build system changes.** If you are at FB, you should review these changes in fbcode, as there are several changes in files that are not exported to ShipIt.

The build system changes in this patch are substantively complicated by the fact that I have to implement these changes five times:

* OSS cmake build
* OSS Bazel build
* FB fbcode Buck build
* FB xplat Buck build (selective build)
* FB ovrsource Buck build

Due to technical limitations in the xplat Buck build related to selective build, it is required that you list every ufunc header manually (this is done in tools/build_variables.bzl).

The OSS cmake changes are entirely in cmake/Codegen.cmake; there is a new set of files, cpu_vec_generated (corresponding to UfuncCPUKernel files), which is wired up in the same way as other files. These files are different because they need to get compiled multiple times under different vectorization settings. I adjust the codegen, slightly refactoring the inner loop into its own function so I can use a different base path calculation depending on whether the file is traditional (in the native/cpu folder) or generated (new stuff from this diff).

The Bazel/Buck changes are organized around tools/build_variables.bzl, which contains the canonical list of ufunc headers (aten_ufunc_headers), and tools/ufunc_defs.bzl (added to the ShipIt export list in D34465699), which defines a number of functions that compute the generated cpu, cpu kernel and cuda files based on the headers list. For convenience, these functions take a genpattern (a string with a {} for interpolation) which can be used to easily reformat the list of formats in target form, which is commonly needed in the build systems.
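As an illustration of the genpattern idea, here is a hedged sketch written as plain Python (the real helpers live in tools/ufunc_defs.bzl and are Starlark; the function bodies and the helper name below are assumptions based on the description above, not copied from the tree):

```
# Canonical list of ufunc headers; in the real tree this lives in tools/build_variables.bzl.
aten_ufunc_headers = [
    "aten/src/ATen/native/ufunc/add.h",
]

def _ufunc_names(headers):
    # "aten/src/ATen/native/ufunc/add.h" -> "add"
    return [h.split("/")[-1][:-len(".h")] for h in headers]

def aten_ufunc_generated_cpu_sources(genpattern="{}"):
    return [genpattern.format("UfuncCPU_{}.cpp".format(n)) for n in _ufunc_names(aten_ufunc_headers)]

def aten_ufunc_generated_cpu_kernel_sources(genpattern="{}"):
    return [genpattern.format("UfuncCPUKernel_{}.cpp".format(n)) for n in _ufunc_names(aten_ufunc_headers)]

def aten_ufunc_generated_cuda_sources(genpattern="{}"):
    return [genpattern.format("UfuncCUDA_{}.cu".format(n)) for n in _ufunc_names(aten_ufunc_headers)]

# A build file interpolates the generated file names into whatever
# target-relative form it needs, e.g.:
print(aten_ufunc_generated_cpu_sources(":gen_aten[{}]"))
# [':gen_aten[UfuncCPU_add.cpp]']
```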
The split between build_variables.bzl and ufunc_defs.bzl is required because build_variables.bzl is executed by a conventional Python interpreter as part of the OSS cmake build, but we require Skylark features to implement the functions in ufunc_defs.bzl (I did some quick Googling but didn't find a lightweight way to run the Skylark interpreter in open source).

With these new file lists, the rest of the build changes are mostly inserting references to these files wherever necessary; in particular, cpu kernel files have to be worked into the multiple vectorization build flow (intern_build_aten_ops in OSS Bazel). Most of the subtlety relates to selective build. Selective build requires operator files to be copied per overall selective build; as dhruvbird explained to me, glob expansion happens during the action graph phase, but the selective build handling of TEMPLATE_SOURCE_LIST is referencing the target graph. In other words, we can't use a glob to generate deps for another rule, because we need to copy files from wherever (including generated files) to a staging folder so the rules can pick them up.

It can be somewhat confusing to understand which bzl files are associated with which build. Here are the relevant mappings for files I edited:

* Used by everyone - tools/build_tools.bzl, tools/ufunc_defs.bzl
* OSS Bazel - aten.bzl, BUILD.bazel
* FB fbcode Buck - TARGETS
* FB xplat Buck - BUCK, pt_defs.bzl, pt_template_srcs.bzl
* FB ovrsource Buck - ovrsource_defs.bzl, pt_defs.bzl

Note that pt_defs.bzl is used by both xplat and ovrsource. This leads to the "tiresome" handling for enabled backends, as selective build is CPU only, but ovrsource is CPU and CUDA.

BTW, while I was at it, I beefed up fb/build_arvr.sh to also do a CUDA ovrsource build, which was not triggered previously.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D31306586
Pulled By: ezyang
fbshipit-source-id: 210258ce83f578f79cf91b77bfaeac34945a00c6
(cherry picked from commit d65157b0b894b6701ee062f05a5f57790a06c91c)
0032fa7725
Add a Functionalization pass in core (#64432)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64432

Original PR description + feedback here: https://github.com/pytorch/pytorch/pull/63048

I've addressed all of the feedback in the original PR and made some pretty large changes, listed below.

**Table of Contents**

- Starting points
- List of the main changes from the original PR
- Next Steps
- Example codegen output (for a view, mutation, and view+mutation op)

**Starting Points**

A good place to start when looking through the PR:

* Alban mentioned that this is a useful mental model (thanks Ed for originally making this clear to me). Semantically, the pass currently does THREE things, which are all needed by functorch - all fused together into one big pass.
  * (a) alias removal, which replaces {view} calls with {view}_copy calls, and manually tracks aliasing information, so that when one tensor is mutated, we re-apply the same mutation to all of the aliases. This is the bulk of the work - once this is done, the next 2 things are trivial to implement.
  * (b) mutation removal, which is easy to do once we know that there are no aliases. Every mutation `a.add_(b)` becomes `a.replace_(a.add(b))`
  * (c) reapplying views: all of the `{view}_copy` calls are replaced with `{view}` calls again. This is an optimization that we can make specifically for functorch (and strided backends), which only care about mutation removal and not alias removal
* XLA and Vulkan only want (a), or (a) + (b). Later, we'll want to split this out so that you can actually opt into different versions of this logic.
* There is currently no {view}_copy replacement, because the <replace views with copies> and <replace copies with views> steps have just been combined. Later, we'll want to actually implement {view}_copy variants of each view operator, probably with codegen.
* documentation breadcrumb 1, in `FunctionalTensorWrapper.cpp`: https://github.com/pytorch/pytorch/pull/64432/files#diff-a0bac99bf205dba5b94cb64fc2466d3d55d991887572f9cd6a02e27b3a91dd60R59 (you might have to expand the `FunctionalTensorWrapper.cpp` file, which GitHub closes by default because it's large)
* documentation breadcrumb 2, in `FunctionalTensorWrapper.h`: https://github.com/pytorch/pytorch/pull/64432/files#diff-c945c71a4ccac65871f24a912e8904f9a5088b24a32e636727ea9c8fe920708aR12
* Reading through the codegen output at the bottom of this description.

**Main changes from the original PR**

(1) I use lambdas instead of a giant enum to handle all of the different views. This results in less boilerplate per view op (and more stuff that can be codegen'd). Every `ViewMeta` object now contains `forward` and `reverse` lambdas that know how to replay the view and its inverse. This makes the actual code that executes the replaying logic a lot less boilerplate-y (see `Alias::sync_update_operations` and `FunctionalTensorWrapper::sync_`).

(2) Every tensor during the functionalization pass is always wrapped in a `FunctionalTensorWrapper`. This is potentially unnecessary for Vulkan/XLA, and will have a mild perf impact, but for now this PR just targets the functorch use case. I previously had a complicated design (a `FunctionalTensorImplBase` class) to avoid needing the wrapper for XLA, but it had some subtleties that are gonna require more thought to fix, so I'm pushing that off for now.

(3) `FunctionalTensorWrapper` objects accurately report stride information.
It's a little annoying to do this though, because the logic that calculates stride info for each view isn't easily separated from the actual view kernels in core, `at::native::{view}`. I do this by adding logic in each `at::functionalization::{view}` kernel to call the reference implementation `at::native::{view}`. I don't do anything with the output aside from taking its size/stride/storage_offset to set the actual output tensor's size/stride/storage_offset correctly.

There's another annoying part to this: I'm pretty sure that we want to pass the actual *wrapper* tensors directly into the native kernels, not their inner unwrapped values. But there are some `at::native::{view}` kernels that call other tensor methods, which re-invokes the dispatcher, calling functionalization/functorch kernels that try to do the unwrapping. To handle this, right now I have an `AutoDispatchDirectlyToNative` guard that basically ensures that any tensor methods called inside of the at::native::{view} op always redispatch straight to the CPU kernel (which will be another at::native:: kernel). This feels kind of heavy handed, but I'm not sure of a better way to do it.

(4) `FunctionalTensorWrapper` objects accurately report aliasing information. There's a new `FunctionalStorageImpl` class (subclass of `StorageImpl`) that allows tensors in the functionalization pass to accurately alias storage. If two tensors `a` and `b` in a functionalized program are views of one another, then `a.storage.is_alias_of(b.storage)` should return true. I added this in a pretty similar way to how meta tensors allocate storage, although I don't pass in an actual allocator (I think this is fine because you should never resize a functional tensor's storage).

One thing I'm not sure about - should `FunctionalTensorWrapper` set `storage_access_should_throw_`: (a) always, (b) never, (c) only if its wrapped tensor has it set? Right now I have it not set, mostly because calling the reference view functions (`at::native::{view}`) requires looking at the storage. But that means that if you try to access storage from python in a functionalized program, you'll get silent garbage instead of an error. Related question: are we planning on exposing meta tensor storage to python in the future (even though it contains garbage)?

(5) better docs :)

**View operator coverage**

(6) The functionalization pass now gets math-composite view ops for free. I didn't add the `Functionalize` dispatch key to the composite set, because I don't want composite ops like `torch.ones` to get decomposed before hitting the functionalization pass. Instead, I added codegen to manually register the `at::native::` kernels of composite view ops. This is a little hairy, because the names of the `at::native::` kernels aren't easily accessible. They're stored in a `Dict[DispatchKey, BackendIndex]`. I made a best-effort attempt to get each view kernel's name, basically by assuming that every view op has either a composite or cpu implementation. There's also a hardcoded list of composite view ops in `gen_inplace_or_view_type.py`, but it looks like it's wrong. This is probably worth rationalizing later, but instead I created a new list of the "complete" set of composite view ops, and preserved the old set by hardcoding the delta between the two sets.

(7) I've added codegen for ops that are both views AND mutations, like `transpose_()` (why do we even have these 😢).
From some light testing, it looks like they work correctly with one caveat: I had a hard time ensuring that functorch programs that mutate their inputs using ops like `transpose_()` preserve the input mutations after the program finishes running. For now (in my corresponding functorch branch) I emit a warning when this happens, and just don't preserve the mutation.

(8) I added `{view}_inverse` implementations for every view op, in `FunctionalInverses.cpp`. These are needed to take mutations made to views and replay them back onto the base. To reduce boilerplate, the codegen generates function declarations for each `{view}_inverse` function, so you get a nice compiler error when someone eventually adds a new view op.

The only view ops currently not supported are (a) as_strided, and (b) the sparse view ops (values()/indices()). I can add support for as_strided, but it needs an `as_strided_inverse()` function. That will look really similar to the `as_strided_backward()` function in FunctionsManual.cpp, but it has some noticeable differences: we basically want an `as_strided_embed` for autograd and `as_strided_scatter` for functionalization. We also will probably need them to be primitives w.r.t. autograd, since the current implementation for autograd uses view().copy_() calls that XLA won't be able to handle. I'm wondering if anyone has any objections, but otherwise I can make those changes (which will require writing backward formulas for `as_strided_embed` and `as_strided_scatter`).

I did a bunch of manual testing that all looks pretty good, but it's definitely not fully tested. Ed pointed out that once XLA uses this pass (or at least once there's a POC), we can just run the existing xla view test suite. Hopefully that delay is okay - if it's not, maybe we can think about using OpInfos similar to how functorch uses them for testing.

Note: there's some duplication with autograd's view code. Every `{view}_inverse` implementation is really similar to the implementation for that view listed in `derivatives.yaml`. There are some major differences though:

* the autograd implementations of those backward functions (like `permute_backwards()`, in `FunctionsManual.cpp`) internally call other view ops. For functionalization, we want them to (eventually) call `{view}_copy` operators.
* For view ops that take a subset of the original storage, like `slice/select/diagonal/as_strided()`, the autograd backward functions fill the "spaces" in the inverse call with zeroes. For functionalization, we want to fill them with the value of `base` at those positions. It looks like this currently applies to 6 total ops (since we can ignore composites):
  * select
  * slice
  * diagonal
  * as_strided
  * split
  * split_with_sizes

A nice end state would probably be for the autograd + functionalization codegen to both look at the same yaml (either `derivatives.yaml`, or something else), and automatically generate the right thing. I didn't leave that in scope for this PR though.

**Current State + Next Steps**

There are a bunch of followups after this PR eventually lands. Roughly in order:

* Use the current pass to register problematic composite ops in functorch. Also, nested `functionalize()` calls aren't supported yet (I mostly just need to remove some debug asserts and test it).
* Work on freeing up dispatch key space by deduplicating the `{backend}`/`Autograd{backend}`/`Sparse{backend}`/`Quantized{backend}` keys
* Once we have more dispatch keys, split up this pass into 3 pieces - it's currently fused, and doesn't do the right thing for vulkan/XLA. Specifically, all of the `{view}` calls in the current pass's view-replay logic should turn into `{view}_copy` calls that vulkan/XLA know how to implement, and there will be separate passes for (a) removing mutations, and (b) turning `{view}_copy` calls back into `{view}` calls. For Vulkan, we eventually want a pass that ONLY removes aliasing and view calls, and doesn't remove mutations. We can also probably make the 2 new passes user dispatch keys to save dispatch key space, if they'll only be used by functorch anyway.
* Do more of a dive on perf for the vulkan/xla use cases. There are several areas to improve perf with varying levels of effort required. The simplest one that I'll probably do regardless is to codegen the out-of-place kernels instead of using a boxed fallback. Getting a POC working for xla will also be useful to test the view operator coverage.

**Example Codegen Output**

View Op:

```
::std::vector<at::Tensor> split_Tensor(c10::DispatchKeySet ks, const at::Tensor & self, int64_t split_size, int64_t dim) {
  auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
  ::std::vector<at::Tensor> out;
  {
    at::AutoDispatchBelowFunctionalize guard;
    auto tmp_output = at::redispatch::split(ks & c10::after_func_keyset, self_, split_size, dim);
    out = at::functionalization::impl::wrapFunctionalTensor(tmp_output);
    // I'm fusing the [alias removal], [mutation removal], [add views back] passes together.
    // Later, we'll want to turn them into separate passes (since e.g. vulkan only cares about alias removal).
  }
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
      [split_size, dim](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
        return base.split(split_size, dim)[mutated_view_idx];
      },
      [split_size, dim](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
        return at::functionalization::impl::split_inverse(base, mutated_view, mutated_view_idx, split_size, dim);
      }
  );
  at::functionalization::impl::set_view_meta(out, self, view_meta);
  at::AutoDispatchDirectlyToNative native_guard;
  ::std::vector<at::Tensor> reference_tensor_output = at::native::split(self, split_size, dim);
  at::functionalization::impl::set_strides(out, reference_tensor_output);
  return out;
}
```

Mutation Op:

```
at::Tensor & add__Tensor(c10::DispatchKeySet ks, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  at::functionalization::impl::sync(self);
  at::functionalization::impl::sync(other);
  auto self_ = at::functionalization::impl::unwrapFunctionalTensor(self);
  auto other_ = at::functionalization::impl::unwrapFunctionalTensor(other);
  at::Tensor tmp_output;
  {
    at::AutoDispatchBelowFunctionalize guard;
    // The functionalization pass explicitly doesn't pass out= parameters to the redispatch
    tmp_output = at::redispatch::add(ks & c10::after_func_keyset, self_, other_, alpha);
  }
  self.replace_(tmp_output);
  at::functionalization::impl::maybe_add_update(self);
  return self;
}
```

View + Mutation Op:

```
at::Tensor & transpose_(c10::DispatchKeySet ks, at::Tensor & self, int64_t dim0, int64_t dim1) {
  at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
      [dim0, dim1](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
        return base.transpose(dim0, dim1);
      },
      [dim0, dim1](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
        return at::functionalization::impl::transpose_inverse(base, mutated_view, dim0, dim1);
      }
  );
  at::functionalization::impl::mutate_view_meta(self, view_meta);
  // See Note [Propagating strides in the functionalization pass]
  // Directly update the sizes/strides/storage_offset fields on self using the inplace call.
  // I need the guard because I don't want the at::native kernel to end up calling more functionalization/functorch kernels.
  // Its only job is to directly compute the output size/stride/storage_offset metadata.
  at::AutoDispatchDirectlyToNative native_guard;
  at::native::transpose_(self, dim0, dim1);
  return self;
}
```

Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D31942093
Pulled By: bdhirsh
fbshipit-source-id: b95598dae35dd1842fa8b1d8d1448332f3afaadf
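To make the ViewMeta forward/reverse idea from this commit concrete, here is a toy, single-level Python model of view replay. It is only a sketch of the concept (Python lists standing in for tensors, no storage or aliasing bookkeeping, no view chains), not the C++ implementation:

```
from dataclasses import dataclass
from typing import Callable

@dataclass
class ViewMeta:
    forward: Callable   # recompute the view's data from the (possibly updated) base
    reverse: Callable   # scatter an updated view back into the base's data

class FunctionalList:
    """Toy stand-in for FunctionalTensorWrapper: a list plus view bookkeeping."""
    def __init__(self, data, base=None, view_meta=None):
        self.data = list(data)
        self.base = base
        self.view_meta = view_meta

    def slice(self, start, end):
        meta = ViewMeta(
            forward=lambda base_data: base_data[start:end],
            reverse=lambda base_data, view_data: base_data[:start] + list(view_data) + base_data[end:],
        )
        return FunctionalList(meta.forward(self.data), base=self, view_meta=meta)

    def add_(self, other):
        # Mutation removal: compute out-of-place, then "replace_" our own contents.
        self.data = [x + y for x, y in zip(self.data, other)]
        # Alias bookkeeping: replay the update onto the base via the reverse lambda.
        if self.base is not None:
            self.base.data = self.view_meta.reverse(self.base.data, self.data)
        return self

base = FunctionalList([1, 2, 3, 4])
view = base.slice(1, 3)   # a "view" of elements 1..2
view.add_([10, 10])       # mutate the view
print(base.data)          # [1, 12, 13, 4]: the mutation was replayed onto the base
```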
1d2ea76afb
clamp: port to structured kernel (#61361)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61361

This PR ports the `clamp` kernel to the structured format. In addition, it introduces `OptionalScalarRef` as a replacement for `c10::optional<Scalar>&`. The latter, although it is a reference type, can still involve copying the contained `Scalar` (e.g. if the actual parameter is a `Scalar`, or if a `c10::optional<Scalar>` is constructed just to call a kernel). `OptionalScalarRef` contains only a `const Scalar&`, and stores a flag about whether the instance contains something inside the `Scalar` itself, using a new tag.

For more information, see #55070.

Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D29821533
Pulled By: SplitInfinity
fbshipit-source-id: 88d55df5a4b2c14b68a57e4905d90eea1b088d99
1c80b5220b
nll_loss_forward: port to structured kernel (#61443)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61443 For more information, see #55070. This PR also adds a new type, `OptionalTensorRef` as a replacement for `c10::optional<Tensor>&` in order to avoid the reference count manipulations that are inevitable with the latter. I have confirmed using Godbolt/Compiler Explorer that this class does indeed avoid manipulating the reference count of the `intrusive_ptr` inside the `Tensor` it refers to: 1. [P429709479](https://www.internalfb.com/phabricator/paste/view/P429709479) - Given a `const Tensor&` in scope, an `OptionalTensorRef` can be constructed without bumping refcount. 2. [P429709883](https://www.internalfb.com/phabricator/paste/view/P429709883) - Given an `OptionalTensorRef`, a `const Tensor&` can be produced without bumping refcount. 3. [P429710335](https://www.internalfb.com/phabricator/paste/view/P429710335) - When `OptionalTensorRef` is destructed, the refcount should not be decremented. 4. [P429769525](https://www.internalfb.com/phabricator/paste/view/P429769525) - `OptionalTensorRef` can be assigned without refcount manipulation. 5. [P429769882](https://www.internalfb.com/phabricator/paste/view/P429769882) - `OptionalTensorRef` can be move assigned without refcount manipulation. Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D29780666 Pulled By: SplitInfinity fbshipit-source-id: 7af157215300e9254d635433cbd583f7329fe064 |
eca98fedb5
split out NamedCType from CType. Remove direct string comparison from autograd codegen (#55334)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55334 The goal of this PR is to clean up some of the autograd codegen to compare C++ types using `CType` objects instead of raw strings. My last PR in the stack made that string comparison a little more fragile, since the raw C++ strings needed to be namespace-aware. I confirmed byte-for-byte no codegen changes vs. the last PR (which added namespaces to the codegen) by running `diff -qr ../pytorch-common_test/torch/csrc/autograd/generated/ ../pytorch-callgrind_test_after2/torch/csrc/autograd/generated/` and `diff -qr ../pytorch-common_test/build/aten/src/ATen/ ../pytorch-callgrind_test_after2/build/aten/src/ATen/` Note that a better end-state for the autograd codegen would be to do all of its type pattern matching directly off of JIT types, instead of off of CType’s (which are really just generated from JIT types, incorporating C++ specific semantics). That looks like it’ll require a pretty substantial change though, so I’m not doing it in this PR. As part of this change (and after talking with ezyang), I split off the `CType` data class into a separate `NamedCType` class, which holds a name and a `CType`. This way, `CType` only knows about actual C++ types, making it easier to compare CType’s to each other in the codegen when we only care about the type. The core change is in `types.py`, but it required a bunch of downstream changes to update all of the places where we create `CType`s to create `NamedCType`s instead. The main change in the autograd codegen was that I updated `SavedAttribute` to store a `NamedCType`. The other autograd changes all pretty much came from that change. Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D27708347 Pulled By: bdhirsh fbshipit-source-id: 3e07c80569c7b229c638f389e76e319bff6315f9 |
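A minimal Python sketch of the NamedCType/CType split described in this commit, with simplified stand-in classes (the real ones in `tools/codegen/api/types.py` carry more structure):

```
from dataclasses import dataclass

@dataclass(frozen=True)
class BaseCType:
    type: str                       # e.g. "at::Tensor"

    def cpp_type(self):
        return self.type

@dataclass(frozen=True)
class ConstRefCType:
    elem: BaseCType

    def cpp_type(self):
        return "const {} &".format(self.elem.cpp_type())

@dataclass(frozen=True)
class NamedCType:
    name: str                       # the binding's name, e.g. "self"
    type: object                    # a CType: BaseCType, ConstRefCType, ...

    def decl(self):
        return "{} {}".format(self.type.cpp_type(), self.name)

tensorT = BaseCType("at::Tensor")
self_binding = NamedCType("self", ConstRefCType(tensorT))
out_binding = NamedCType("out", ConstRefCType(tensorT))

# Type comparisons no longer involve raw strings or the binding's name:
assert self_binding.type == out_binding.type
print(self_binding.decl())          # const at::Tensor & self
```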
947c7a8215
add C++ namespacing logic to ctypes (#55047)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55047

Added namespaces to all of the `CTypes` printed in the codegen. This is pretty much required if we want to use codegen externally, since we can no longer assume that we're inside of the `at::` namespace.

Important changes are in `types.py`. How do we add the notion of namespaces to C++ types without people having to write "at::Tensor" everywhere? Before this PR, `CType` held a raw string representing the type, i.e. `BaseCType("Tensor", binds)`. This PR introduces a set of singleton base C++ types in `types.py` that know how to print their namespace. Instead, we'd write `BaseCType(tensorT, binds)`, where printing `tensorT` will properly print out "at::Tensor". This also means that you can't create arbitrary `CTypes`. If we need a new C++ type in the codegen, we need to add it to the list in `types.py`.

One blip in the design: we don't want to change `RegistrationDeclarations.yaml`, since that'll break external backends that ingest it. I added separate functions to display types without the namespace that are used to create `RegistrationDeclarations.yaml`. With an external codegen API though, we can eventually kill it :)

I also didn't realize until this PR that `Declarations.yaml` is still directly in use, by some python/autograd codegen. Rather than keep that yaml byte-for-byte compatible, I just updated the callsites in the autograd codegen to work with namespaces. In the NEXT PR, I try to clean up some of the autograd codegen to stop using raw strings to match against C++ types.

Test Plan: Imported from OSS
Reviewed By: bhosmer
Differential Revision: D27708349
Pulled By: bdhirsh
fbshipit-source-id: 56a4f81fc101795bcb9ee1f722121480fb2356ad
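A hedged Python sketch of the singleton-type idea from this commit (illustrative names and a simplified shape; the actual `types.py` also threads the binding name through, and has a longer list of registered types):

```
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class BaseCppType:
    ns: Optional[str]    # namespace, e.g. "at"; None for builtins
    name: str            # e.g. "Tensor"

    def __str__(self):
        return "{}::{}".format(self.ns, self.name) if self.ns else self.name

# Singletons: codegen may only use C++ types that are registered here.
intT = BaseCppType(None, "int64_t")
tensorT = BaseCppType("at", "Tensor")
scalarT = BaseCppType("at", "Scalar")

def cpp_type(t):
    return str(t)        # namespace-qualified, e.g. "at::Tensor"

def cpp_type_registration_declarations(t):
    return t.name        # legacy spelling kept for RegistrationDeclarations.yaml

print(cpp_type(tensorT))                            # at::Tensor
print(cpp_type_registration_declarations(tensorT))  # Tensor
```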
4753100a3b
Un-ignore F403 in .flake8 (#55838)
Summary: Generally wildcard imports are bad for the reasons described here: https://www.flake8rules.com/rules/F403.html This PR replaces wildcard imports with an explicit list of imported items where possible, and adds a `# noqa: F403` comment in the other cases (mostly re-exports in `__init__.py` files). This is a prerequisite for https://github.com/pytorch/pytorch/issues/55816, because currently [`tools/codegen/dest/register_dispatch_key.py` simply fails if you sort its imports](https://github.com/pytorch/pytorch/actions/runs/742505908). Pull Request resolved: https://github.com/pytorch/pytorch/pull/55838 Test Plan: CI. You can also run `flake8` locally. Reviewed By: jbschlosser Differential Revision: D27724232 Pulled By: samestep fbshipit-source-id: 269fb09cb4168f8a51fd65bfaacc6cda7fb87c34 |
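A small, generic illustration (using the standard library, not code from the PR) of the two remediations described above:

```
# Before: a wildcard import trips flake8's F403 and hides what is actually used.
# from os.path import *

# After: spell out the imported names so linters (and readers) can check them.
from os.path import dirname, join

print(join(dirname("/tmp/a/b.txt"), "c.txt"))  # /tmp/a/c.txt

# For deliberate re-exports (e.g. a package __init__.py), keep the wildcard but
# acknowledge it explicitly:
# from some_package.submodule import *  # noqa: F403
```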
c5a67f1675
Fix minor inaccuracy in translate error reporting (#53032)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53032

Previously, you could get this error message:

```
Failed to synthesize the expression "Tensor & out".
When I failed, the following bindings were available in the context:

  const Tensor & self;
  const Tensor & other;
  Scalar alpha;
  const Tensor & op.outputs_[0];
```

There's a problem with this error message: it doesn't seem like there is any 'out' argument available, but actually there is: the last binding in the context is it. We printed the *expression*, not the *ctype name*. After this patch, the context now prints as:

```
  const Tensor & self; // self
  const Tensor & other; // other
  Scalar alpha; // alpha
  const Tensor & out; // op.outputs_[0]
```

Now it becomes clear that it's a const mismatch. Maybe we could also beef up the error message so it points out near misses, but I'll leave that to future work.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: ljk53
Differential Revision: D26729768
Pulled By: ezyang
fbshipit-source-id: adb363551a7145eac788943c20969c86b1f8a81b
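A rough Python sketch of the improved rendering, using an illustrative Binding shape rather than the real codegen classes:

```
from dataclasses import dataclass

@dataclass
class Binding:
    ctype_decl: str   # e.g. "const Tensor & out" (type plus binding name)
    expr: str         # the C++ expression that produces it, e.g. "op.outputs_[0]"

context = [
    Binding("const Tensor & self", "self"),
    Binding("const Tensor & other", "other"),
    Binding("Scalar alpha", "alpha"),
    Binding("const Tensor & out", "op.outputs_[0]"),
]

def render_context(ctx):
    # Print the ctype (what translate actually matches on), and keep the
    # originating expression as a trailing comment.
    return "\n".join("  {}; // {}".format(b.ctype_decl, b.expr) for b in ctx)

print('Failed to synthesize the expression "Tensor & out".')
print("When I failed, the following bindings were available in the context:\n")
print(render_context(context))
```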
c9c4b871a5
[pytorch] reintroduce static dispatch (#51957)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51957

This is a simplified version of #51554. Compared to #51554, this version only supports statically dispatching to a specific backend. The benefit is that it skips the dispatch key computation logic and thus has less framework overhead. The downside is that if input tensors do not match the specified backend it will throw an error instead of falling back to regular dispatch.

Sample code:

```
Tensor empty(IntArrayRef size, TensorOptions options, c10::optional<MemoryFormat> memory_format) {
  return at::cpu::empty(size, options, memory_format);
}

// aten::conj(Tensor(a) self) -> Tensor(a)
Tensor conj(const Tensor & self) {
  return at::math::conj(self);
}

// aten::conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
Tensor & conj_out(Tensor & out, const Tensor & self) {
  return at::cpu::conj_out(out, self);
}

// aten::conj.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
Tensor & conj_outf(const Tensor & self, Tensor & out) {
  return at::cpu::conj_out(out, self);
}

// aten::_conj(Tensor self) -> Tensor
Tensor _conj(const Tensor & self) {
  return at::defaultbackend::_conj(self);
}
```

For ops without the specific backend dispatch, it will throw an error:

```
// aten::_use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool
bool _use_cudnn_ctc_loss(const Tensor & log_probs, const Tensor & targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank) {
  TORCH_CHECK(false, "Static dispatch does not support _use_cudnn_ctc_loss for CPU.");
}
```

Differential Revision: D26337857
Test Plan: Imported from OSS
Reviewed By: bhosmer
Pulled By: ljk53
fbshipit-source-id: a8e95799115c349de3c09f04a26b01d21a679364
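A hedged Python sketch of the codegen-side decision this PR makes; the structure, arguments, and name below are illustrative, not the actual function in `tools/codegen/gen.py`:

```
def static_dispatch(name, arg_names, backends_with_kernel, static_backend=None):
    """Return the C++ body to emit for an operator when a static dispatch
    backend is requested; None means 'use the normal dispatcher path'."""
    if static_backend is None:
        return None
    if static_backend in backends_with_kernel:
        # Call the backend's namespaced kernel directly, skipping dispatch key computation.
        return "return at::{}::{}({});".format(static_backend.lower(), name, arg_names)
    # No kernel registered for this backend: fail loudly at runtime instead of falling back.
    return ('TORCH_CHECK(false, "Static dispatch does not support '
            '{} for {}.");'.format(name, static_backend))

print(static_dispatch("conj_out", "out, self", {"CPU", "CUDA"}, "CPU"))
# return at::cpu::conj_out(out, self);
print(static_dispatch("_use_cudnn_ctc_loss", "...", {"CUDA"}, "CPU"))
# TORCH_CHECK(false, "Static dispatch does not support _use_cudnn_ctc_loss for CPU.");
```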
4d85e30133
Support at::cpu on non-structured kernels (#51590)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51590 This PR backports a subset of Jiakai's changes from https://github.com/pytorch/pytorch/pull/51554 that adds support for at::cpu in non-structured kernels. The unusual bits: - Need to add a new forward inference rule for doing conversions of const optional<Tensor>& to const Tensor& - Need to give the wrapper functions a prefix so that the call to wrapper is not ambiguous Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: ljk53 Differential Revision: D26209871 Pulled By: ezyang fbshipit-source-id: 8162686039675ab92a2af7a14f6b18941f8944df |
81c7c3bae5
Add api.structured; switch structured kernels to use const Tensor& everywhere (#51490)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51490 Mutable Tensor ref is a source of endless confusion for kernel writers; if we're going to make everyone rewrite their kernels, might as well also get rid of mutable Tensor& while we're at it. This is a refactor-then-small-update double whammy. The refactor is to separate tools.codegen.api.structured from api.native for describing the type signatures of structured kernels (previously, I was naughtily reusing native for this purpose--now I need it to behave differently as Tensor). This started off as a copy paste, but since there are not that many structured kernels so far I could delete all of the legacy logic from native that didn't make sense (without having to go out and fix all the use sites all at once). One more small addition was teaching translate to convert Tensor& to const Tensor&. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D26182413 Pulled By: ezyang fbshipit-source-id: ed636866add3581179669cf9283f9835fcaddc06 |
648cdb7d0a
Relax type signature for tools.codegen.api.translate (#51477)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51477 Passing in a full binding is still OK, but if you have less (e.g., an Expr/CType), that will do too. I'll need this for some codegen patches I'm doing later. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D26179560 Pulled By: ezyang fbshipit-source-id: 5730dfb2c91bf5325496e57b0c91eb6823c9194d |
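A small Python sketch of what "relaxing" the input type can look like, with illustrative stand-ins for Binding/Expr/CType (not the actual translate signature):

```
from dataclasses import dataclass
from typing import List, Union

@dataclass(frozen=True)
class CType:
    cpp: str              # e.g. "const Tensor &"

@dataclass(frozen=True)
class Expr:
    expr: str             # a C++ expression
    type: CType

@dataclass(frozen=True)
class Binding:
    name: str
    type: CType

def normalize(ctx: List[Union[Binding, Expr, CType]]) -> List[Expr]:
    """Accept full Bindings, bare Exprs, or bare CTypes, and turn them all into
    the Exprs the translator works with internally."""
    out = []
    for item in ctx:
        if isinstance(item, Expr):
            out.append(item)
        elif isinstance(item, Binding):
            out.append(Expr(item.name, item.type))  # a binding's name is a valid expression
        else:
            out.append(Expr("/* unnamed */", item))  # type-only input
    return out

print(normalize([Binding("self", CType("const Tensor &")), CType("Scalar")]))
```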
3efd5d8f01
Introduce tools.codegen.api.translate (#49122)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49122 cpparguments_exprs has induced a lot of head scratching in many recent PRs for how to structure the code in a good way. This PR eliminates the old algorithm for an entirely new algorithm inspired by logic programming. The net result is shorter, cleaner and should be more robust to future changes. This PR is a bit of a whopper. Here is the order to review it. - tools/codegen/api/types.py - Deleted CppArgument, CppArgumentPackIface (and subclasses), CppExpr, DispatcherExpr, DispatcherArgument, NativeExpr, NativeArgument, MetaArgument. All things previously called XArgument are now Binding. All things previously called XExpr are now Expr. I deleted the `__str__` implementation on Binding and fixed all call sites not to use it. On Binding, I renamed `str_no_default` and `str_default` to `defn` and `decl` for better symmetry with the corresponding signature concepts, although I'm open to naming them back to their original versions. - Obviously, things are less type safe without the class distinctions. So I introduce a new ADT called CType. CType represents the *semantic C++ type* of a binding: it is both the C++ type (e.g., `const Tensor&`) as well as the argument name that specifies what the binding denotes (e.g., `other`). Every binding now records its CType. The key observation here is that you don't actually care if a given expression is from the cpp or dispatcher or native API; what you care is having enough information to know what the expression means, so you can use it appropriately. CType has this information. For the most part, ArgNames are just the string names of the arguments as you see them in JIT schema, but there is one case (`possibly_redundant_memory_format`) where we encode a little extra information. Unlike the plain strings we previously used to represent C++ types, CType have a little bit of structure around optional and references, because the translation code needs to work around these concepts. - I took the opportunity to kill all of the private fields like `_arguments` and `_returns_type` (since the argument types don't make sense anymore). Everything is computed for you on the fly. If this is a perf problem in codegen we can start using `cached_property` decorator. - All of the heavy lifting in CppSignature.argument_packs has been moved to the cpp module. We'll head over there next. Similarly, all of the exprs methods are now calling translate, the new functionality which we haven't gotten to yet - tools/codegen/api/cpp.py - We refactor all of the type computation functions to return CType instead of str. Because CTypes need to know the denotation, there is a new `binds: ArgName` argument to most functions that provides the denotation, so we can slot it in. (An alternative would have been to construct CTypes without denotations and then fill them in post-facto, but I didn't do it this way. One downside is there are some places where I need a CType without denotation, so I fill these in with `__placeholder__` whenever this happens). - `argument` and `arguments` are now extremely simple. There is no more Pack business, just produce one or more Bindings. The one thing of note is that when both a `memory_format` and `options` are in scope, we label the memory format as `possibly_redundant_memory_format`. This will be used in translation - tools/codegen/api/dispatcher.py and tools/codegen/api/native.py - same deal as cpp.py. 
One thing is that `cpparguments_exprs` is deleted; that is in the translator - tools/codegen/api/translate.py - the translator! It uses a very simple backwards deduction engine to work out how to fill in the arguments of functions. There are comments in the file that explain how it works. - Everything else: just some small call site tweaks for places when I changed API. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: ljk53 Differential Revision: D25455887 Pulled By: ezyang fbshipit-source-id: 90dc58d420d4cc49281aa8647987c69f3ed42fa6 |
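To illustrate the logic-programming flavor of the new translator, here is a toy Python sketch of goal-directed synthesis. It is far simpler than the real `tools/codegen/api/translate.py` (which works over structured CTypes and handles references and optionals); the rule set and goal spellings below are assumptions for illustration only:

```
# Rules: how to synthesize an expression of a goal type from sub-goal types.
# Each entry maps a goal to (C++ template, list of sub-goals).
RULES = {
    "opmath_t alpha": ("static_cast<opmath_t>({0})", ["scalar_t alpha"]),
    "scalar_t alpha": ("{0}.to<scalar_t>()", ["const Scalar & alpha"]),
}

def solve(goal, ctx):
    """ctx maps the types we already have in scope to the expressions that produce them."""
    if goal in ctx:
        return ctx[goal]                 # direct hit: a binding is already available
    if goal in RULES:
        template, subgoals = RULES[goal]
        return template.format(*(solve(g, ctx) for g in subgoals))
    raise RuntimeError("Failed to synthesize the expression " + repr(goal))

ctx = {"const Scalar & alpha": "alpha", "const Tensor & self": "self"}
print(solve("opmath_t alpha", ctx))
# static_cast<opmath_t>(alpha.to<scalar_t>())
```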