This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info.
1. Cpp template infrastructure
Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates.
2. Initial FP32 gemm template
This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking while allows the static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with ATen vec abstraction.
3. Correctness and performance
The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static shape and dynamic shapes. Since it is an initial implementation, we are still working on further performance improves with follow-up PRs including the optimizations in kernels as well as fusions. The perf gains are only observed from a selective number of models compared to the ATen kernels which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details.
Static shapes
| Benchmark | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |
Key models being sped up:
drq: 1.14x
soft_act: 1.12
cait_m36_384: 1.18x
Dynamic shapes
| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |
Key models being sped up:
BERT_pytorch: 1.22x
pyhpc_turbulent: 1.13x
soft_actor_critic: 1.77x
BlenderbotForCausalLM: 1.09x
cait_m36_384: 1.17x
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021
Approved by: https://github.com/jansel
Enable nonzero workspace and Cutlass StreamK for Inductor Cutlass GEMM ops.
This is a simpler rewrite of my original version of #119005 using @peterbell10 's workspace allocation mechanism from #117992
Test Plan:
- Additional unit test in test_cutlass_backend.py which specifically tests StreamK GEMM with workspace requirement
- CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125406
Approved by: https://github.com/jansel
This provides utilities for creating and querying properties on
sympy.Symbol. I want to use this refactor to get a better handle on how
the 's' prefix is being used in Inductor. To start, I only do
symbolic_shapes code because that's what I'm familiar with.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125395
Approved by: https://github.com/Skylion007
Fixes cutlass_utils.get_max_alignment() which was so far not checking the alignment properly. Basically
the method so far assumed that the passed layout is contiguous and row-major, which does not have to be true.
Test Plan:
CI - test_cutlass_backend.py to prevent regressions
Added unit test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124930
Approved by: https://github.com/int3
ghstack dependencies: #124929
Minor refactoring:
Remove unused "fused epilogue node" arguments from some method Kernel call signatures.
Test Plan:
Covered by current tests in test_cutlass_backend.py - no functional change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124929
Approved by: https://github.com/eellison
This completely subsumes https://github.com/pytorch/pytorch/pull/120816
This makes use of the unbacked binding machinery to teach Inductor how to generate deferred runtime asserts directly. There is some back story about why I did it this way, let me explain.
Previously, our strategy for generating runtime asserts was that Dynamo would insert them into the FX graph after finishing tracing, and we would attempt to code generate them based on the FX graph. This is a good strategy for export, where we immediately export the graph. However, this strategy was afflicted by problems in eager, where we reuse the same ShapeEnv as before. In particular, on subsequent graph passes, we would immediately turn all of these assertions into noops, because when we evaluated their expressions, we would see that because we had a deferred runtime assert in the ShapeEnv, we know "oh, of course this expression is True" already. Oops!
So, with this PR, we take the attitude that as long as the ShapeEnv sticks around, the ShapeEnv's list of deferred runtime asserts is the source of truth, and we don't put anything in the graph. So we just need to decide when to actually generate asserts, and the place I picked was Inductor lowering, since we already have an AssertScalar buffer concept, and so I just need to insert them at this point. AssertScalar also uses raw sympy.Expr rather than SymInt/Bool, so it is easier to prevent unrestricted simplification at this point.
There are a few things jumbled together in this PR. I can split them if you want, but some of the changes are before I changed my strategy, but they're useful changes anyway.
**torch/_dynamo/output_graph.py** and **torch/_inductor/lowering.py** - Here, we stop putting deferred runtime asserts in the graph. I also have to make sure we don't DCE unused symbol arguments; we're going to get some goofy graph arguments this way, will be good to restore that optimization eventually. We also just disable codegen for `_assert_scalar` entirely; we assume that ShapeEnv will be good enough to capture all of these.
**torch/_inductor/codegen/wrapper.py** and **torch/_inductor/ir.py** - Add a way to codegen sizevars without forcing simplification
**torch/_inductor/graph.py** - The main logic. Our strategy is to interpose in the same place we are testing that unbacked SymInts are properly showing up in lowered code. The logic is directly analogous to the logic in the existing insert deferred runtime asserts FX pass, but it's simpler because sympy expressions can be directly stored on inductor IR nodes.
**torch/fx/experimental/symbolic_shapes.py** - For extra safety, we have a way of freezing runtime asserts, so that if you try to add more we error. This prevents us from adding runtime asserts after we've done lowering. There's a funny interaction with backwards which there's a comment for in graph.py
**torch/fx/passes/runtime_assert.py** - This is not really needed in this PR, but I rewrote the runtime assert logic to use unbacked_bindings rather than inferring it by looking for unbacked SymInts. Now, keypaths are translated into FX node acessors. Unfortunately, I couldn't delete the old inference code, because you still need it to find backed SymInts from arguments (as this pass may be used on graphs which don't explicitly bind all their shape variables as argments). There are some new tests exercising this.
TODO: I think we need to generate asserts for replacements too. This is a preexisting problem that the old FX pass had too.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124874
Approved by: https://github.com/jansel
ghstack dependencies: #124864
Improves the Cutlass backend GEMM template:
* Adds code which allows to create stand-alone test runners for Cutlass GEMM Kernels, which allows (manual) debugging of, for example, CUDA IMA errors or similar problems which occur in practice. Includes some utility code and tests to actually compile and run these standalone tests.
* Cleans up the GEMM template code through various refactorings
* Eliminates code sections and options that are unneccessary now that epilogue fusions are being removed.
* Limits the scope of a workaround for (flaky) Cutlass issues with bias broadcasting to neccessary cases.
* Puts some CPU runtime checks into #if / #endif blocks, such that it's possible to compile CUTLASS Kernels with lower CPU overhead.
* Add documentation comments
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124577
Approved by: https://github.com/jansel
ghstack dependencies: #124576
fixes#123039
In abi mode, ExternKernelSchedulerNode generates code using `aoti_torch_tensor_copy_` which requires `AtenTensorHandle`, but the allocation generates ArrayRefTensor to allocate mem in stack. To fix this issue, this PR prevents ExternKernelSchedulerNode from using stack-mem-allocation in abi, and creates AtenTensorHandle instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124037
Approved by: https://github.com/desertfire
When inductor generates triton code, the triton code can either assume that the inputs given to it are aligned or unaligned. If they are aligned, triton can use more efficient instructions (like vectorized loads or tensor cores). However, if we generate "aligned" code and pass in unaligned inputs, the triton code will error out; to fix this, we clone unaligned inputs that are passed to triton kernels that expect aligned inputs. This can lead to excessive clones if we have inputs that are not expected to be aligned.
In this PR, we use the example input to decide whether the generated triton code should assume alignment or not. If the example input is aligned, then we will generate triton code that assumes alignment; if at runtime we receive an unaligned input, we'll make a clone. Meanwhile, if the example input is not aligned, the generated triton code will not assume inputs are aligned and we won't ever need to clone.
Note that the alignment of the inputs is not guarded on; we found that adding guards on tensor offsets (a) was slow in cases where we do a lot of comparisons on tensor offsets, and (b) led to a lot of recompilations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123319
Approved by: https://github.com/eellison
This is a subset of changes extracted from https://github.com/pytorch/pytorch/pull/124683/
This PR contains modifications to make Inductor work with unbacked symbol inputs, which can occur when a data-dependent sized tensor is saved for backwards. The problems to be fixed:
* When binding initial symbols, we unconditionally bind unbacked symbols (instead of computing if they are needed, which only looks at backed symbols)
* Benchmark generation code doesn't work with unbacked symints as we have no hints to actually feed in real values. So I pick a random number and you are expected to fix it if it doesn't work
* Need to make sure we don't install dependencies on unbacked SymInt inputs, that puts us down the "promptly deallocate the input" path, but that's pointless for unbacked SymInt
Fixes https://github.com/pytorch/pytorch/issues/124652
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124739
Approved by: https://github.com/jansel
ghstack dependencies: #124310, #124314, #124316, #124394
This PR has a lot of "draw the rest of the fucking owl" energy. Here's how to break it down.
1. **torch/_inductor/graph.py** - We start by tightening unbacked symbol invariants. Specifically, as we lower FX nodes, we check whether or not every unbacked_binding recorded on the FX node meta, actually ends up getting bound (according to get_unbacked_symbol_defs) in all the buffers generated by the lowering. Hopefully this invariant is self evident. This leads to a lot of failures.
2. **torch/_inductor/ir.py** - Problem 1: There is softness in how Inductor computes defs of unbacked symbols in IR node. Previously, we tried to infer it by looking at the output sizes/strides/etc and see if new unbacked symbols popped up that we hadn't seen in the inputs. I don't know exactly what was buggy about the old code, but sometimes we would fail to notice an unbacked symbol had been bound, or rebind an unbacked symbol multiple times. Fortunately, thanks to the earlier PRs in our stack, we now have a nice list of unbacked symbol bindings from FX, so we now just store it directly on ExternKernel and use it directly to report defs. This has to be done twice: once for FallbackKernel (e.g., nonzero) and once for DynamicScalar (e.g., item) (see also **torch/_inductor/lowering.py**, **torch/_inductor/codegen/wrapper.py** and **torch/_inductor/codegen/cpp_wrapper_cpu.py** for the lowering and codegen changes for item)
* **process_kernel** - Sidequest! It turns out that Inductor lowering can reallocate unbacked symbols. This happens specifically when we repropagate fake tensors through the operator in `process_kernel`. This repropagation process is necessary because Inductor may have changed the strides of input tensors, and it must now recompute the strides so that it can continue to appropriately plan the rest of the lowering process. This is fine: we just make sure we do the rebind unbacked + compute_unbacked_bindings dance we've been doing previously in the PR stack. But instead of putting unbacked_bindings on a new FX node, they go straight into our unbacked_bindings on the Inductor IR node.
* **codegen_unbacked_symbol_defs** - Sidequest! FallbackKernel lowering is done in two steps. First, you emit the FallbackKernel buffer. Then, you emit MultiOutput buffers which actually give access to the individual outputs of FallbackKernel, which may have been multi-output. There is a design decision here: does the FallbackKernel bind the unbacked symbols, or the MultiOutput buffer? Historically, we put the binding on MultiOutput buffer, because it's more convenient: the FallbackKernel buffer is fake, in fact, it doesn't even get a name in C++ codegen. But it's kind of inconsistent with the keypath model that we've been tracking unbacked bindings with: if you have a multi-output node, you'd expect a keypath like `[0].size()[0]` representing the first output's first dimension size. That suggests that it's the FallbackKernel that should define the things. So that was my first implementation. Unfortunately, the C++ codegen is too cursed and I could not understand how to make it work in that case. So now we just unsoundly assume you cannot have multi-output data dependent output, and do the codegen in MultiOutput. There are some comments explaining exactly what we are improperly assuming.
3. **_rename_unbacked_to** in **torch/fx/experimental/symbolic_shapes.py** - Previously, when we renamed unbacked symbols, we clobbered any facts we previously knew about them. So for example, if we had a replacement `u0 -> s0` but then we renamed u0 to u1, we would now setup the replacement `u0 -> u1`, clobbering the old replacement. This apparently didn't matter in earlier PRs in the stack, but with Inductor now on the ball, there were some tests that indicated this was a problem. The solution is easy: if u0 had a preexisting replacement, reapply it to u1. However...
* **torch/_functorch/_aot_autograd/collect_metadata_analysis.py** - When we run forward analysis, this triggers fake tensor repropagation and fresh allocations. Previously, we just cleared out the pending symbols when finished the analysis. But with the change above, this would also migrate replacements to the new symbols... which are now dead. So now we explicitly suppress generation of these symbols with `ignore_fresh_unbacked_symbols` so that no rebinding happens at all.
* **torch/_dynamo/eval_frame.py** - same deal; I just searched for all sites we called clear() on pending
4. The last step is fixing the long tail of extra problems that show up, now that unbacked_bindings are load bearing into Inductor
* **torch/_dynamo/eval_frame.py** - Some of the exports are making copies of nodes without repropagating fake tensors, so in this case, it is important to also copy the `unbacked_bindings` (apparently this didn't matter before without the Inductor changes)
* **torch/_export/pass_base.py** - I discover that this is doing fake tensor repropagation via a test suite failure. Do the same playbook as AOTAutograd: PropagateUnbackedSymInts too! Actually, they also have implemented their own tracer as well, so do the same playbook as proxy_tensor: record unbacked_bindings on the newly traced nodes. UGH code duplication.
* **torch/_subclasses/fake_tensor.py**, **torch/_subclasses/fake_impls.py** (with call site updates at **torch/_functorch/_aot_autograd/traced_function_transforms.py** and **torch/fx/passes/fake_tensor_prop.py**) - What's this new epoch thing? I noticed that sometimes I would be retracing, call nonzero() on a fake tensor, and not allocate a new unbacked symbol. This is actually bad, because if I don't get a new unbacked symbol, I don't know there's a binding site, and `unbacked_bindings` is now missing a binding. The reason for this is memoization: if I reuse the exact same fake tensor on my retrace, it will already have an unbacked symint memoized on it and we will short circuit allocation. Well, that's no good. So I associate the memos with a fake tensor epoch, and every time you start a new fake tensor propagation from scratch, you bump the epoch so that I clear all the memos.
* **torch/_inductor/scheduler.py** - I notice in unit tests that V.current_node is not always set when we call process_kernel. So I save it into the IR node and restore it when we are running `get_estimated_runtime`.
* **torch/fx/experimental/symbolic_shapes.py** - A few things
* **rebind_unbacked** (re **_tensor_version**). Ordinarily, when you have an unbacked SymInt, you persistently hvae it all the way to the end of the program. `_tensor_version` violates this: this generates an unbacked SymInt (for reasons I don't quite understand?) and then gets rid of it later. This triggered an assert violation. I think this op is kind of misusing unbacked SymInt, but I didn't know how to refactor it, so it gets a special case.
* **rebind_unbacked** (re **Simplify SymBool binding**). Ugh, SymBool, what a pain in the butt. I have an assert that you can only rebind unbacked symbol to another unbacked symbol. This assert fails when a boolean is involved, because the result of running keypath on the result is not `u1`, it's `sympy.Piecewise(... sympy.Eq(u1, 1) ...)`. This is actually just `u1`, but Sympy doesn't know it because it doesn't know that `u1` value range is `[0, 1]`. So we manually implement the simplification needed to get the assert to pass.
* **compute_unbacked_bindings** (re **This is pretty fragile**). There is a really funny disaster involving memoization and Inductor process kernel. Ordinarily when I retrace, if there was a memo hit in the old trace, there will be a memo hit in the new trace. However, Inductor process kernel breaks this, because it recreates fake tensor inputs to the operator call from scratch (since they might have different strides), and obviously these tensor inputs don't have the memo from the old one. I tried a little bit to try to manually transplant the memo to the new fake tensor but it seemed hopeless, so I just let the fresh symbol ride, allocating a new unbacked symbol. However, in one of our tests, we rely on knowing that the first nonzero call is equal to the second (memoized) nonzero call. The equality test looked pretty easy to discharge, so I just went ahead and added a deferred runtime assert to this effect and it worked.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124394
Approved by: https://github.com/jansel
ghstack dependencies: #124310, #124314, #124316