Commit Graph

253 Commits

Author SHA1 Message Date
Jiong Gong
037615b989 [inductor][cpp] GEMM template (infra and fp32) (#124021)
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info.
1. Cpp template infrastructure
Template abstractions similar to the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`, plus the MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates.
2. Initial FP32 gemm template
This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with a constant weight (`B`), requiring `N` to be a multiple of the register blocking while allowing static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix is prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`), then invokes `CppMicroGemm`, which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with the ATen vec abstraction.
3. Correctness and performance
The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static and dynamic shapes. Since this is an initial implementation, we are still working on further performance improvements in follow-up PRs, including kernel optimizations as well as fusions. Perf gains over the ATen kernels (which are implemented with MKL) are currently observed only on a select number of models. The gains are more pronounced with dynamic shapes, since MKL only supports packed gemm for static shapes. Details below; a brief usage sketch follows the benchmark details.

Static shapes
| Benchmark | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up:
drq: 1.14x
soft_act: 1.12x
cait_m36_384: 1.18x

Dynamic shapes
| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up:
BERT_pytorch: 1.22x
pyhpc_turbulent: 1.13x
soft_actor_critic: 1.77x
BlenderbotForCausalLM: 1.09x
cait_m36_384: 1.17x
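As a usage sketch (not part of the diff), the template is exercised through Inductor's max-autotune path; the `"CPP"` value for `max_autotune_gemm_backends` is an assumption about how this backend is selected:

```
# Hedged usage sketch: assumes the new template is picked up when "CPP" is
# listed in max_autotune_gemm_backends; treat that value as an assumption.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune_gemm_backends = "CPP,ATEN"  # autotune between the CPP template and ATen/MKL

linear = torch.nn.Linear(1024, 1024).eval()  # constant weight B, prepacked by the template

with torch.no_grad():
    compiled = torch.compile(linear, mode="max-autotune", dynamic=True)
    out = compiled(torch.randn(8, 1024))  # M (batch dim) may be static or dynamic
```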

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021
Approved by: https://github.com/jansel
2024-05-12 07:46:44 +00:00
lezcano
320af5eaa6 Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did
not come from the FX graph. Now we propagate the bounds whenever we have
a rule for that op.
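A rough standalone illustration of the idea (plain interval arithmetic; the rule table and class here are hypothetical, not Inductor's actual `ValueRanges` machinery):

```
# Hypothetical sketch of per-op bound propagation with an unbounded fallback.
from dataclasses import dataclass
import math

@dataclass(frozen=True)
class Interval:
    lower: float
    upper: float

# Per-op propagation rules: given bounds on the inputs, bound the output.
RULES = {
    "add": lambda a, b: Interval(a.lower + b.lower, a.upper + b.upper),
    "abs": lambda a: Interval(
        0.0 if a.lower <= 0.0 <= a.upper else min(abs(a.lower), abs(a.upper)),
        max(abs(a.lower), abs(a.upper)),
    ),
}

UNBOUNDED = Interval(-math.inf, math.inf)

def propagate(op, *args):
    # Before: variables not coming from the FX graph got no bounds at all.
    # Now: if we have a rule for the op, propagate bounds through it;
    # otherwise fall back to the unbounded interval.
    rule = RULES.get(op)
    return rule(*args) if rule else UNBOUNDED
```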

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-05-08 08:14:06 +00:00
PyTorch MergeBot
2a42c40791 Revert "Compute bounds for the variables created during codegen (#123100)"
This reverts commit bb668c6468.

Reverted https://github.com/pytorch/pytorch/pull/123100 on behalf of https://github.com/huydhn due to Sorry for reverting you change but it is failing inductor tests bb668c6468 ([comment](https://github.com/pytorch/pytorch/pull/123100#issuecomment-2096837821))
2024-05-06 20:23:39 +00:00
PyTorch MergeBot
7ffa5558ee Revert "[FX] Update type hints in torch.fx._compatibility.py (#125469)"
This reverts commit 235b4d6ec2.

Reverted https://github.com/pytorch/pytorch/pull/125469 on behalf of https://github.com/izaitsevfb due to breaks pyre in dependent projects (internal: see D56986361) ([comment](https://github.com/pytorch/pytorch/pull/125469#issuecomment-2096665396))
2024-05-06 18:36:43 +00:00
lezcano
bb668c6468 Compute bounds for the variables created during codegen (#123100)
Before we would just bail out on these bounds for all variables that did
not come from the FX graph. Now we propagate the bounds whenever we have
a rule for that op.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123100
Approved by: https://github.com/jgong5, https://github.com/peterbell10
2024-05-06 18:12:15 +00:00
Yifu Wang
58d8388ed3 Remove Inductor IRs for legacy functional collectives (#124992)
This PR completely removes the Inductor IR for legacy functional collectives:
- Removed the `CollectiveKernel` hierarchy and `Wait`, as well as the corresponding lowerings. These IRs are target (i.e. Python) specific and don't model node dependencies properly (e.g. they rely on `never_reuse_buffers` for correct behavior). They've been superseded by `ir._CollectiveKernel`.
- Removed `InPlaceHint` and the scheduler logic for handling it. `InPlaceHint` is a codegen-time buffer reuse mechanism controlled by the IR's codegen. It's a bit hacky and overlaps with the default buffer reuse mechanism. Removing it since it is only used by legacy functional collectives.
- Removed `OutputBuffer` and `MultiOutputNoSizeAssert` which are designed for and only used by legacy functional collectives.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124992
Approved by: https://github.com/Chillee, https://github.com/wanchaol
2024-05-05 19:49:58 +00:00
Xuehai Pan
235b4d6ec2 [FX] Update type hints in torch.fx._compatibility.py (#125469)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125469
Approved by: https://github.com/Skylion007
ghstack dependencies: #125468
2024-05-05 19:30:22 +00:00
Edward Z. Yang
6f70d22277 Extend torch.utils._sympy.symbol for more Inductor symbols (#125419)
I'm still missing a few, cdzq at least

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125419
Approved by: https://github.com/lezcano
ghstack dependencies: #125395
2024-05-04 09:05:00 +00:00
eellison
46f326eff5 explicitly reset stderr/stdout in precompilation (#125289)
I was seeing a weird bug where, after running max-autotune, my stdout would be misdirected. Other people have not been able to repro this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125289
Approved by: https://github.com/shunting314, https://github.com/mlazos
2024-05-01 23:41:36 +00:00
Sam Larsen
75a8e9ee77 [inductor] better cache clearing in fx graph cache tests (#125280)
Summary: There's a shortcoming in the FX graph cache tests: they don't fully clear all inductor in-memory caches when testing the cache-hit path. We were previously accessing the FX graph cache correctly, but when loading the source object using the PyCodeCache.load_by_key_path() method, _that_ path was serving entries from its in-memory cache. To better mimic what happens during a warm start (i.e., a new process), we should clear all in-memory caches.
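A simplified sketch of the warm-start pattern the updated tests aim for; the reset helper below is a stand-in, since the real tests use inductor's own cache-reset utilities:

```
# Sketch only: the real tests clear inductor-specific in-memory caches
# (e.g. PyCodeCache's table); torch._dynamo.reset() stands in for that here.
import torch

def f(x):
    return (x.sin() + 1).sum()

compiled = torch.compile(f)
compiled(torch.randn(8))   # cold run: compiles and populates the on-disk FX graph cache

torch._dynamo.reset()      # drop in-memory state to mimic a warm start (new process)

compiled = torch.compile(f)
compiled(torch.randn(8))   # should now be served via the on-disk FX graph cache
```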

Test Plan: updated the unit tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125280
Approved by: https://github.com/eellison
2024-05-01 04:47:46 +00:00
Yanbo Liang
7478b7f1ca Add common used score_mod functions for templated attention (#124670)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124670
Approved by: https://github.com/Chillee
2024-04-27 21:04:52 +00:00
Simon Fan
855939904b [cudagraphs] add more info to skip messages (#124700)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124700
Approved by: https://github.com/eellison
ghstack dependencies: #119729
2024-04-26 03:22:29 +00:00
David Berard
4259e5d0e0 [inductor] Specialize on unguarded alignment of example inputs (#123319)
When inductor generates triton code, the triton code can either assume that the inputs given to it are aligned or unaligned. If they are aligned, triton can use more efficient instructions (like vectorized loads or tensor cores). However, if we generate "aligned" code and pass in unaligned inputs, the triton code will error out; to fix this, we clone unaligned inputs that are passed to triton kernels that expect aligned inputs. This can lead to excessive clones if we have inputs that are not expected to be aligned.

In this PR, we use the example input to decide whether the generated triton code should assume alignment or not. If the example input is aligned, then we will generate triton code that assumes alignment; if at runtime we receive an unaligned input, we'll make a clone. Meanwhile, if the example input is not aligned, the generated triton code will not assume inputs are aligned and we won't ever need to clone.

Note that the alignment of the inputs is not guarded on; we found that adding guards on tensor offsets (a) was slow in cases where we do a lot of comparisons on tensor offsets, and (b) led to a lot of recompilations.
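A simplified sketch of the runtime behavior described above (the alignment constant and helper are illustrative, not the actual generated wrapper code):

```
# Illustrative sketch of "specialize on example-input alignment, clone otherwise".
import torch

ALIGNMENT = 16  # bytes; the value inductor actually assumes is an implementation detail

def is_aligned(t: torch.Tensor) -> bool:
    # data_ptr() includes the storage offset, so a view like x[1:] of an
    # aligned tensor can itself be unaligned.
    return (t.data_ptr() % ALIGNMENT) == 0

def prepare_input(t: torch.Tensor, code_assumes_aligned: bool) -> torch.Tensor:
    # If the compiled kernel was specialized for aligned inputs (because the
    # example input was aligned) but this runtime input is not, clone it into
    # a freshly allocated -- and therefore aligned -- buffer.
    if code_assumes_aligned and not is_aligned(t):
        return t.clone()
    return t
```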

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123319
Approved by: https://github.com/eellison
2024-04-25 22:28:15 +00:00
PyTorch MergeBot
6a92b352ee Revert "[cudagraphs] add more info to skip messages (#124700)"
This reverts commit 0ed38c9b22.

Reverted https://github.com/pytorch/pytorch/pull/124700 on behalf of https://github.com/jeanschmidt due to one PR in this stack seems to have broken linux pull cuda12 tests ([comment](https://github.com/pytorch/pytorch/pull/119729#issuecomment-2076750595))
2024-04-25 09:26:25 +00:00
Simon Fan
0ed38c9b22 [cudagraphs] add more info to skip messages (#124700)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124700
Approved by: https://github.com/eellison
ghstack dependencies: #119729
2024-04-25 03:38:09 +00:00
Edward Z. Yang
660db767ef Don't clean up fresh inductor cache on error (#124620)
Useful for local debugging.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124620
Approved by: https://github.com/oulgen, https://github.com/desertfire, https://github.com/jansel
2024-04-23 02:13:05 +00:00
Bin Bao
bb37910e30 [AOTI] Fixes ScatterFallback codegen (#124580)
Summary: For https://github.com/pytorch/pytorch/issues/123184. ScatterFallback currently relies on op name matching for codegen, which makes its cpp codegen fragile. Refactor to use op_overload and fix the relevant unit test failures.

Differential Revision: [D56417815](https://our.internmc.facebook.com/intern/diff/D56417815)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124580
Approved by: https://github.com/chenyang78
2024-04-22 20:47:26 +00:00
Jason Ansel
7fd8870e6b [inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557)
I am planning to make the compile_worker process not import torch so it can start up much faster.  This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124557
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552, #124553
2024-04-22 18:46:24 +00:00
Jason Ansel
bb8815bc31 [inductor] Refactor runtime files into torch._inductor.runtime (part 2) (#124553)
I am planning to make the compile_worker process not import torch so it can start up much faster.  This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124553
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552
2024-04-22 18:46:20 +00:00
PyTorch MergeBot
56714cb497 Revert "[inductor] Refactor runtime files into torch._inductor.runtime (part 2) (#124553)"
This reverts commit f4d47f5bbb.

Reverted https://github.com/pytorch/pytorch/pull/124553 on behalf of https://github.com/jeanschmidt due to There are internal breakages, already discussed with author and he'll FF ([comment](https://github.com/pytorch/pytorch/pull/124552#issuecomment-2070548223))
2024-04-22 18:28:05 +00:00
PyTorch MergeBot
0b90af0bf5 Revert "[inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557)"
This reverts commit fcf28b0ad5.

Reverted https://github.com/pytorch/pytorch/pull/124557 on behalf of https://github.com/jeanschmidt due to There are internal breakages, already discussed with author and he'll FF ([comment](https://github.com/pytorch/pytorch/pull/124552#issuecomment-2070548223))
2024-04-22 18:28:05 +00:00
Jason Ansel
fcf28b0ad5 [inductor] Refactor runtime files into torch._inductor.runtime (part 3) (#124557)
I am planning to make the compile_worker process not import torch so it can start up much faster.  This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124557
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552, #124553
2024-04-22 04:51:15 +00:00
Jason Ansel
f4d47f5bbb [inductor] Refactor runtime files into torch._inductor.runtime (part 2) (#124553)
I am planning to make the compile_worker process not import torch so it can start up much faster.  This stack is prep for that.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124553
Approved by: https://github.com/yanboliang
ghstack dependencies: #124552
2024-04-22 04:51:09 +00:00
Chen, Zejun
b1984237a0 [Profiler] Unify the device(CUDA, XPU, PrivateUse1) in torch profiler post processing (#123247)
This PR unifies the CUDA, XPU and PrivateUse1 devices in the torch profiler. CUDA, XPU and PrivateUse1 now use the string field `use_device` to distinguish each other and share one device path for calculating kineto time durations and memory statistics during post processing.
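For context, the post-processing being unified is what an ordinary profiler run exercises; a minimal example (the `use_device` unification itself is internal and not visible in this API):

```
import torch
from torch.profiler import profile, ProfilerActivity

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    # XPU and PrivateUse1 backends now flow through the same post-processing path.
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, profile_memory=True) as prof:
    x = torch.randn(1024, 1024)
    y = x @ x

# Kineto time durations and memory statistics are computed in the shared device path.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))
```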

#suppress-api-compatibility-check

Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123247
Approved by: https://github.com/aaronenyeshi
2024-04-22 01:26:55 +00:00
Shunting Zhang
c5a4ba2257 [inductor] consider pointwise nodes when deciding reduction hint (#124131)
In certain **rare** scenarios, inductor can generate a reduction kernel with really bad perf. E.g., if
- the reduction kernel contains a reduction node followed by a pointwise node
- and the pointwise node uses a transposed layout,
- the reduction node is an inner reduction,
- and rnumel <= 1024,

then inductor will generate a persistent reduction kernel and it causes really bad perf when doing tl.store for the pointwise node since we use a very skinny tile `(XBLOCK=1, RBLOCK=next_power_of_2(rnumel))` .

I've tried a few versions of the fix.
- The first version was: if any pointwise node in a reduction kernel uses a non-contiguous dependency, use ReductionHint.DEFAULT. This caused an 8s compilation time increase for huggingface with no perf wins... The reason is that ReductionHint.DEFAULT does more autotuning.
- Then I changed the code to be more specific. We change the hint from INNER to DEFAULT if we are sure that the pointwise kernel can use a >1 stride for the lowest dimension. Kernels meeting this condition should mostly have really bad perf anyway.

The situation mentioned above is rare. But it's reported by internal users. I'll also run one more perf test.

Testing script: https://gist.github.com/shunting314/9d3389891fa43633b49b8b7564ad6d8b . Something equivalent is also added as a unit test.

For this specific test from user reports, we improve the mentioned reduction kernels perf by **4.14x** (451us -> 109us)
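A hypothetical repro sketch of the pattern (inner reduction with small rnumel fused with a pointwise op that reads a transposed, non-contiguous operand); whether Inductor actually fuses these into one persistent reduction kernel depends on sizes and heuristics:

```
# Hypothetical repro sketch; shapes are illustrative and require a CUDA device.
import torch

@torch.compile
def fused(x, y_t):
    s = x.sum(dim=-1, keepdim=True)  # inner reduction, rnumel = 512 <= 1024
    return s * y_t                   # pointwise op over a transposed layout

x = torch.randn(4096, 512, device="cuda")
y_t = torch.randn(512, 4096, device="cuda").t()  # shape (4096, 512), non-contiguous
out = fused(x, y_t)
```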

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124131
Approved by: https://github.com/jansel
2024-04-20 05:07:56 +00:00
PyTorch MergeBot
520bc1080e Revert "[Profiler] Unify the device(CUDA, XPU, PrivateUse1) in torch profiler post processing (#123247)"
This reverts commit 768ce2cdda.

Reverted https://github.com/pytorch/pytorch/pull/123247 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/123247#issuecomment-2066152611))
2024-04-19 09:09:03 +00:00
Chen, Zejun
768ce2cdda [Profiler] Unify the device(CUDA, XPU, PrivateUse1) in torch profiler post processing (#123247)
This PR unifies the CUDA, XPU and PrivateUse1 devices in the torch profiler. CUDA, XPU and PrivateUse1 now use the string field `use_device` to distinguish each other and share one device path for calculating kineto time durations and memory statistics during post processing.

#suppress-api-compatibility-check

Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123247
Approved by: https://github.com/aaronenyeshi, https://github.com/gujinghui
2024-04-19 03:31:13 +00:00
Pearu Peterson
43b4ac956e Add index_reduce decomposition (#122579)
As in the title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122579
Approved by: https://github.com/peterbell10
ghstack dependencies: #123375
2024-04-18 01:30:47 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
Sam Larsen
e5b404b809 [inductor] Fix fresh_inductor_cache() (#122661)
Summary: Modify fresh_inductor_cache() to clear cached state before mocking the toplevel cache_dir directory. Any lru_caches (or otherwise) can use the @clear_on_fresh_inductor_cache decorator to register the cache for clearing. Also change the base inductor TestCase class to use fresh_inductor_cache(). Previously that TestCase was only mocking the subdirectory within the toplevel cache dir designated for the FX graph cache artifacts.
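A usage sketch under the assumption that both helpers are importable from `torch._inductor.utils` (the import path is an assumption, not taken from the diff):

```
# Sketch; the import path for clear_on_fresh_inductor_cache is an assumption.
import functools
import torch
from torch._inductor.utils import clear_on_fresh_inductor_cache, fresh_inductor_cache

@clear_on_fresh_inductor_cache      # register this cache so fresh_inductor_cache() clears it
@functools.lru_cache(None)
def _expensive_lookup(key):
    return key * 2

def f(x):
    return (x.cos() + 1).sum()

with fresh_inductor_cache():
    # Both the mocked on-disk cache_dir and the registered in-memory caches
    # start empty here, so this compile is a true cold start.
    torch.compile(f)(torch.randn(8))
```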

Test Plan:
- New unit test
- All existing inductor tests will exercise fresh_inductor_cache()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122661
Approved by: https://github.com/oulgen
2024-04-15 20:28:54 +00:00
Jason Ansel
6022600cc6 [inductor] Handle meta tensor ops in graph (#123786)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123786
Approved by: https://github.com/anijain2305
ghstack dependencies: #123700, #123705
2024-04-12 19:03:13 +00:00
PyTorch MergeBot
d994d993c0 Revert "[inductor] Fix fresh_inductor_cache() (#122661)"
This reverts commit cda383e7bc.

Reverted https://github.com/pytorch/pytorch/pull/122661 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/122661#issuecomment-2051171028))
2024-04-12 07:26:50 +00:00
Sam Larsen
cda383e7bc [inductor] Fix fresh_inductor_cache() (#122661)
Summary: Modify fresh_inductor_cache() to clear cached state before mocking the toplevel cache_dir directory. Any lru_caches (or otherwise) can use the @clear_on_fresh_inductor_cache decorator to register the cache for clearing. Also change the base inductor TestCase class to use fresh_inductor_cache(). Previously that TestCase was only mocking the subdirectory within the toplevel cache dir designated for the FX graph cache artifacts.

Test Plan:
- New unit test
- All existing inductor tests will exercise fresh_inductor_cache()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122661
Approved by: https://github.com/oulgen
2024-04-10 20:38:56 +00:00
xinan.lin
9743e3a19c [Inductor Intel GPU backend Upstream] Add Inductor Intel GPU backend. (#121895)
Following the design in RFC https://github.com/pytorch/pytorch/issues/114856, this PR implements the Intel GPU Inductor backend by:
- Reusing WrapperCodegen and TritonScheduling for Python wrapper and kernel code generation, and implementing device-specific code generation in XPUDeviceOpOverrides (a rough sketch follows below).
- Reusing fx passes, lowering, codecache, Triton kernel auto-tuning, and compilation.

For testing, this PR provides test/inductor/test_xpu_basic.py for basic inductor backend functionality testing.
We'll reuse all the existing Inductor test cases in the next PR.
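A rough sketch of the kind of device-specific codegen hooks the first bullet describes; the method names and emitted `torch.xpu` calls are assumptions based on the summary, not copied from the diff:

```
# Hypothetical sketch of device-specific wrapper-codegen overrides for XPU.
# Each method returns the Python source that the generated wrapper will emit.
class XPUDeviceOpOverridesSketch:
    def set_device(self, device_idx: int) -> str:
        return f"torch.xpu.set_device({device_idx})"

    def synchronize(self) -> str:
        return "torch.xpu.synchronize()"

    def device_guard(self, device_idx: int) -> str:
        # Context-manager expression used to scope generated kernels to a device.
        return f"torch.xpu.device({device_idx})"
```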

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121895
Approved by: https://github.com/EikanWang, https://github.com/jansel, https://github.com/desertfire
2024-04-05 09:05:11 +00:00
PyTorch MergeBot
a808559fc6 Revert "[inductor] Fix fresh_inductor_cache() (#122661)"
This reverts commit ba7d396eb7.

Reverted https://github.com/pytorch/pytorch/pull/122661 on behalf of https://github.com/clee2000 due to new test is failing internally ([comment](https://github.com/pytorch/pytorch/pull/122661#issuecomment-2037977934))
2024-04-04 18:55:55 +00:00
Sam Larsen
ba7d396eb7 [inductor] Fix fresh_inductor_cache() (#122661)
Summary: Modify fresh_inductor_cache() to clear cached state before mocking the toplevel cache_dir directory. Any lru_caches (or otherwise) can use the @clear_on_fresh_inductor_cache decorator to register the cache for clearing. Also change the base inductor TestCase class to use fresh_inductor_cache(). Previously that TestCase was only mocking the subdirectory within the toplevel cache dir designated for the FX graph cache artifacts.

Test Plan:
- New unit test
- All existing inductor tests will exercise fresh_inductor_cache()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122661
Approved by: https://github.com/oulgen
2024-04-04 02:32:37 +00:00
Kai Londenberg
74b3a7920e [Inductor Cutlass backend] GEMM size threshold for Cutlass backend usage (#121491)
* Adds a configurable GEMM size threshold for the usage of Cutlass GEMM kernels, **_inductor.config.cutlass_backend_min_gemm_size** (a config sketch follows the list below)

 * During GEMM algorithm choice generation: **if no viable choices can be generated using the configured backends, the ATen backend will be used as a fallback backend**, even if it is not enabled in **_inductor.config.max_autotune_gemm_backends**
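A config sketch using the names from the bullets above; the threshold value and backend list are illustrative, not recommendations:

```
# Illustrative values only; requires CUDA and a Cutlass-enabled build.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune_gemm_backends = "CUTLASS,TRITON,ATEN"
# Skip Cutlass choices for GEMMs smaller than this many multiply-accumulates.
inductor_config.cutlass_backend_min_gemm_size = 128 * 128 * 128

a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 512, device="cuda", dtype=torch.float16)
out = torch.compile(lambda x, y: x @ y, mode="max-autotune")(a, b)
```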

Test plan:
CI
Additional unit test in test_cutlass_backend.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121491
Approved by: https://github.com/jansel
ghstack dependencies: #121490
2024-04-03 13:34:16 +00:00
Bin Bao
0ff6155eee [AOTI] Support module buffer mutation (#123164)
Summary: Fixes https://github.com/pytorch/pytorch/issues/120424. Because in a forward pass module buffers may be mutated, we need to allow that in AOTI. In addition, this will be a necessary step if we want to extend AOTI to training.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123164
Approved by: https://github.com/digantdesai, https://github.com/malfet, https://github.com/chenyang78, https://github.com/khabinov
2024-04-02 20:25:26 +00:00
eellison
5f46312dbb Reapply "Switch cudagraph backend to cudagraph trees (#121019)" and "Add Cudagraphs disable checking (#121018)" (#121864) (#122713)
This reverts commit 92ed8553a6.

No longer importing codecache or boxed_nop at top level, both of which caused issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122713
Approved by: https://github.com/anijain2305
2024-04-02 16:11:00 +00:00
Merlin Lüdicke
fdc281f258 [inductor] lower min SM requirement for gemm autotuning to 68 (#123121)
Lower the minimum number of CUDA SMs required for GEMM autotuning from V100 to 3080 level, allowing some high-end consumer GPUs to benefit as well.

Fixes #109489

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123121
Approved by: https://github.com/jansel
2024-04-02 00:28:59 +00:00
Jiong Gong
49121603ab [inductor][cpp] support vectorized indirect indexing (#119655)
This PR adds vectorized indirect indexing so that we can further simplify the `CppVecKernelChecker` (done in the later PR #119734) and remove the check that throws `CppVecUnsupportedError`. A boundary assertion check on indices is added via the new `indirect_assert` method on `Kernel` - the base implementation handles scalar indices and is overridden in `CppVecKernel` for vectorized indices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119655
Approved by: https://github.com/jansel
ghstack dependencies: #119654
2024-03-27 10:25:45 +00:00
Matthew Haddock
50036ec781 [Inductor] Add a test for creating a cpu inductor-> triton backend (#122396)
Summary: Currently there is a test for adding a backend in test/inductor/test_extension_backend.py for a cpp backend with a new device. However there is no such test for the Triton backend; it should be possible for a user to create and register their own ExtensionWrapperCodegen and ExtensionScheduling for another non-CUDA device and be able to generate Triton code. For simplicity I have chosen to use a CPU device, as I think it's plausible someone might want to create a CPU Triton backend.

Unfortunately the generation and running of the code is quite tightly coupled so I've had to use a mocked function to extract the code before running. Suggestions are welcome for better ways to do this.

This is a stepping off point for some additional PRs to make the Triton code path less CUDA specific, as currently there would be no way to test this avenue.
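A condensed sketch of the registration pattern such a test would exercise; `register_backend_for_device` mirrors the hook used by the existing extension-backend test, while the `Extension*` classes here are bare stand-ins for the real subclasses:

```
# Hypothetical sketch; in the real test the Extension* classes subclass
# inductor's wrapper-codegen and scheduling bases rather than being empty.
from torch._inductor.codegen.common import register_backend_for_device

class ExtensionWrapperCodegen:   # stand-in for the wrapper codegen subclass
    pass

class ExtensionScheduling:       # stand-in for the Triton scheduling subclass
    pass

# Route inductor codegen for the "cpu" device through the extension classes,
# so Triton (rather than cpp) kernels get generated for CPU.
register_backend_for_device("cpu", ExtensionScheduling, ExtensionWrapperCodegen)
```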

Test plan:
```
frames [('total', 1), ('ok', 1)]
stats [('calls_captured', 3), ('unique_graphs', 1)]
inductor [('intermediate_hooks', 1)]
aot_autograd [('total', 1), ('ok', 1)]
.
----------------------------------------------------------------------
Ran 1 test in 0.394s
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122396
Approved by: https://github.com/jansel
2024-03-23 01:14:57 +00:00
PyTorch MergeBot
97d3bf71b9 Revert "[Inductor Cutlass backend] GEMM size threshold for Cutlass backend usage (#121491)"
This reverts commit 700c92e1b9.

Reverted https://github.com/pytorch/pytorch/pull/121491 on behalf of https://github.com/huydhn due to Sorry for reverting you change but I think it is failing on ROCm, i.e. 700c92e1b9 ([comment](https://github.com/pytorch/pytorch/pull/121490#issuecomment-2015829464))
2024-03-22 20:11:47 +00:00
Kefei Lu
400cc518fc pt2 dper passes: run shape prop before each pass (#122451)
Summary: Most passes rely on shape info. We need to run shape prop after each pass.
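For reference, shape propagation on an FX graph can be run with `torch.fx.passes.shape_prop.ShapeProp`; a minimal sketch of the "shape prop between passes" pattern (the pass list itself is hypothetical):

```
import torch
from torch.fx.passes.shape_prop import ShapeProp

def run_passes_with_shape_prop(gm: torch.fx.GraphModule, example_inputs, passes):
    for p in passes:  # `passes` is a hypothetical list of graph transforms
        # Refresh node.meta["tensor_meta"] so the next pass sees current shapes.
        ShapeProp(gm).propagate(*example_inputs)
        gm = p(gm) or gm
        gm.recompile()
    return gm
```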

Reviewed By: frank-wei

Differential Revision: D55221119

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122451
Approved by: https://github.com/frank-wei
2024-03-22 17:57:25 +00:00
Kai Londenberg
700c92e1b9 [Inductor Cutlass backend] GEMM size threshold for Cutlass backend usage (#121491)
* Adds a configurable GEMM size threshold for the usage of Cutlass GEMM Kernels **_inductor.config.cutlass_backend_min_gemm_size**

 * During GEMM algorithm choice generation: **if no viable choices can be generated using the configured backends, the ATen backend will be used as a fallback backend**, even if it is not enabled in **_inductor.config.max_autotune_gemm_backends**

Test plan:
CI
Additional unit test in test_cutlass_backend.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121491
Approved by: https://github.com/jansel
ghstack dependencies: #121490
2024-03-22 10:58:43 +00:00
haozhe.zhu
3bc2bb6781 use two pass reduction for deterministic reduction order (#115620)
## Motivation
Address the [non-deterministic reduction order](https://github.com/pytorch/pytorch/issues/93542#issuecomment-1411294181) issue for `omp parallel reduction`.

## Latest update on 1.15:
55d81901bc.
Do not reduce into the array inside the loop. Instead, reduce into a local scalar and write it back to the array after the local reduction is done. This allows the compiler to keep the reduction variable in a register instead of reading/writing memory. If the working set of the loop body is large, the gap between register and memory accesses is significant.
```
vaddss (%xmm0, %xmm11, %xmm11) -> accumulate in register %xmm0
vaddssl ((%rdx, %rdi, 4), %xmm0, %xmm0) -> accumulate in memory address (%rdx, %rdi, 4)
```
Examples code:
```
tmp0_acc_arr[64];
#pragma omp parallel num_threads(64)
{
    auto tid = omp_get_thread_num();
    #pragma omp for
    for(...){
        ....
        tmp0_acc_arr[tid] = tmp0_acc_arr[tid] + tmp_x;  // access array will always from memory
    }
}
```
will be changed to
```
tmp0_acc_arr[64];
#pragma omp parallel num_threads(64)
{
    auto tid = omp_get_thread_num();
    **auto tmp0_acc_local = 0;**
    #pragma omp for
    for(...){
        ....
        **tmp0_acc_local**  = tmp0_acc_local + tmp_x;
    }
    **tmp0_acc_arr[tid] = tmp0_acc_local;**
}
```

## Descriptions
Following ATen, we use a `two pass reduction` with `omp parallel` to get a deterministic reduction order.
9c3ae37fc4/aten/src/ATen/Parallel-inl.h (L39)
9c3ae37fc4/aten/src/ATen/native/TensorIteratorReduce.cpp (L24)
```
            float tmp_acc0 = 0;
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
            // init reduction buffer per thread
            float tmp_acc0_arr[64];
            at::vec::Vectorized<float> tmp_acc0_vec_arr[64];
            for (int tid = 0; tid < 64; tid++)
            {
                tmp_acc0_arr[tid] = 0;
                tmp_acc0_vec_arr[tid] = at::vec::Vectorized<float>(0);
            }
            #pragma omp parallel num_threads(64)
            {
                int tid = omp_get_thread_num();
                #pragma omp for
                for(long x0=static_cast<long>(0L); x0<static_cast<long>(3964928L); x0+=static_cast<long>(16L))
                {
                    auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x0));
                    auto tmp1 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x0));
                    auto tmp2 = tmp0 - tmp1;
                    auto tmp3 = tmp2 * tmp2;
                    // reduce to per thread buffers
                    tmp_acc0_vec_arr[tid] = tmp_acc0_vec_arr[tid] + tmp3;
                }
            }
            // second pass reduce
            for (int tid = 0; tid < 64; tid++)
            {
                tmp_acc0 = tmp_acc0 + tmp_acc0_arr[tid];
                tmp_acc0_vec = tmp_acc0_vec + tmp_acc0_vec_arr[tid];
            }
            tmp_acc0 = tmp_acc0 + at::vec::vec_reduce_all<float>([](at::vec::Vectorized<float>& x, at::vec::Vectorized<float>& y) { return x + y; }, tmp_acc0_vec);
            out_ptr0[static_cast<long>(0L)] = static_cast<float>(tmp_acc0);
```

## Test results
I test this PR with dynamo benchmark on 32-core ICX system,
Result (avg speed up):
| |  before this PR   | after this PR  |
| ---- |  ----  | ----  |
| torchbench | 1.303  | 1.301 |
| huggingface | 1.346  | 1.343 |
| timm_models | 1.971 | 1.970 |

```
export LD_PRELOAD=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libiomp5.so:${CONDA_PREFIX:-"$(dirname $(which conda))/../"}/lib/libjemalloc.so
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
export KMP_AFFINITY=granularity=fine,compact,1,0
export KMP_BLOCKTIME=1

multi_threads_test() {
    CORES=$(lscpu | grep Core | awk '{print $4}')
    export OMP_NUM_THREADS=$CORES
    end_core=$(expr $CORES - 1)
    numactl -C 0-${end_core} --membind=0 python benchmarks/dynamo/${SUITE}.py --${SCENARIO} --${DT} -dcpu -n50 --no-skip --dashboard --only "${MODEL}" ${Channels_extra} ${BS_extra} ${Shape_extra} ${Mode_extra} ${Wrapper_extra} ${Flag_extra} --timeout 9000 --backend=inductor --output=${LOG_BASE}/${SUITE}.csv
}

SCENARIO=performance
DT=float32
export TORCHINDUCTOR_FREEZING=1
Flag_extra="--freezing"
Mode_extra="--inference"

for suite in timm_models huggingface torchbench
do
  export SUITE=$suite
  echo $SUITE
  export LOG_BASE=`date +%m%d%H%M%S`
  mkdir $LOG_BASE
  multi_threads_test
done
```
System info
```
ubuntu@ip-172-31-18-205:~/hz/pytorch$ lscpu
Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         46 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  64
  On-line CPU(s) list:   0-63
Vendor ID:               GenuineIntel
  Model name:            Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
    CPU family:          6
    Model:               106
    Thread(s) per core:  2
    Core(s) per socket:  32
    Socket(s):           1
    Stepping:            6
    BogoMIPS:            5800.00
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic mo
                         vbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xs
                         aveopt xsavec xgetbv1 xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
Virtualization features:
  Hypervisor vendor:     KVM
  Virtualization type:   full
Caches (sum of all):
  L1d:                   1.5 MiB (32 instances)
  L1i:                   1 MiB (32 instances)
  L2:                    40 MiB (32 instances)
  L3:                    54 MiB (1 instance)
NUMA:
  NUMA node(s):          1
  NUMA node0 CPU(s):     0-63
Vulnerabilities:
  Gather data sampling:  Unknown: Dependent on hypervisor status
  Itlb multihit:         Not affected
  L1tf:                  Not affected
  Mds:                   Not affected
  Meltdown:              Not affected
  Mmio stale data:       Mitigation; Clear CPU buffers; SMT Host state unknown
  Retbleed:              Not affected
  Spec rstack overflow:  Not affected
  Spec store bypass:     Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1:            Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:            Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
  Srbds:                 Not affected
  Tsx async abort:       Not affected
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115620
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-03-15 02:03:10 +00:00
Animesh Jain
92ed8553a6 Revert "Switch cudagraph backend to cudagraph trees (#121019)" and "Add Cudagraphs disable checking (#121018)" (#121864)
This reverts commit 9373ad0bb8.

Revert "Add Cudagraphs disable checking (#121018)"

This reverts commit 4af0e634bf.

Causes compilation time increase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121864
Approved by: https://github.com/eellison
2024-03-15 00:03:09 +00:00
Aleksandar Samardžić
1251f0fa31 Add CUTLASS kernel as choice for _int_mm() Inductor autotuning (#119685)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119685
Approved by: https://github.com/cpuhrsch, https://github.com/kadeng
2024-03-14 13:25:23 +00:00
eellison
6ca9ae4f86 Express y grid > 2^16 in terms of z grid (#121554)
CUDA has a max y_grid of 65535. If we need a grid larger than that, we can express it in terms of the z grid, which is currently unused in inductor codegen.
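A small sketch of the decomposition described (numbers only, independent of inductor's actual grid helpers):

```
import math

CUDA_MAX_Y_GRID = 65535

def split_y_grid(y_elems: int) -> tuple[int, int]:
    """Express an oversized y grid as (y, z) with y <= 65535 and y * z >= y_elems."""
    if y_elems <= CUDA_MAX_Y_GRID:
        return y_elems, 1
    z = math.ceil(y_elems / CUDA_MAX_Y_GRID)
    y = math.ceil(y_elems / z)  # keep y*z >= y_elems while minimizing padding
    return y, z

assert split_y_grid(70000) == (35000, 2)
```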

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121554
Approved by: https://github.com/aakhundov
2024-03-12 02:36:19 +00:00
Peter Bell
168a04e752 [inductor] Changes to support newer triton pin (#121267)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121267
Approved by: https://github.com/lezcano
ghstack dependencies: #121438
2024-03-09 18:17:36 +00:00